概要
PySParkで行に0〜nまでの連続する数値を採番したかった。
バージョン情報
- spark-2.3.1
- Python 3.5.5
サンプルデータ
下記のような2つのカラムを持つCSVファイル(100万行)を利用。
$ gunzip -c foo.csv.gz | head -5 0,0.194617 1,0.184299 2,0.988041 3,0.258601 4,0.782715
1つ目のカラムは連番。2つ目のカラムはランダム。
生成コード。ファイルサイズは5.6MB。
import gzip, random, tqdm with gzip.open('foo.csv.gz', 'wt') as w: for i in tqdm.tqdm(range(1000000)): w.write('%d,%f\n' %(i, random.random()))
サンプルデータの読込み
カラム名をid_a, valueとしてcsvファイルを読み込む。
from pyspark.sql import types as T, functions as F # カラムの型と名前を定義 schema = T.StructType([ T.StructField('id_a', T.IntegerType()), T.StructField('value', T.DoubleType()) ]) # 読込み df = spark.read \ .schema(schema) \ .csv('foo.csv.gz') # 5行表示 df.take(5)
実行結果
[Row(id_a=0, value=0.194617), Row(id_a=1, value=0.184299), Row(id_a=2, value=0.988041), Row(id_a=3, value=0.258601), Row(id_a=4, value=0.782715)]
zipWithIndexの利用
RDDにzipWithIndexがあるのでそれを利用すると連番が作れる。
df.rdd.zipWithIndex().take(5)
実行結果
[(Row(id_a=0, value=0.194617), 0), (Row(id_a=1, value=0.184299), 1), (Row(id_a=2, value=0.988041), 2), (Row(id_a=3, value=0.258601), 3), (Row(id_a=4, value=0.782715), 4)]
Rowと連番のTupleが生成されている。
採番されたIDをid_bという名前にしてDataFrameにしてみる。
# rowとidxのtupleを受け取って、idxをカラムに持つrowを作る関数 def new_row(row_idx): row, idx = row_idx dic = row.asDict() dic['id_b'] = idx return T.Row(**dic) # zipWithIndexしてmapで上の関数を呼び出す df2 = df.rdd.zipWithIndex() \ .map(new_row) \ .toDF() df2.take(5)
実行結果
[Row(id_a=0, id_b=0, value=0.194617), Row(id_a=1, id_b=1, value=0.184299), Row(id_a=2, id_b=2, value=0.988041), Row(id_a=3, id_b=3, value=0.258601), Row(id_a=4, id_b=4, value=0.782715)]
monotonically_increasing_idの利用
zipWithIndexだとrddにする必要があるが、monotonically_increasing_idだとDataFrameに対して直接実行できるのでより簡単。
from pyspark.sql.functions import monotonically_increasing_id df2 = df.withColumn('id_b', monotonically_increasing_id()) df2.take(5)
実行結果
[Row(id_a=0, value=0.194617, id_b=0), Row(id_a=1, value=0.184299, id_b=1), Row(id_a=2, value=0.988041, id_b=2), Row(id_a=3, value=0.258601, id_b=3), Row(id_a=4, value=0.782715, id_b=4)]
但し難点もあって、並列で動かすとIDが連番ではなくなる。
df2 = df.repartition(50) \ .withColumn('id_b', monotonically_increasing_id()) df2.agg(F.max('id_b')).show()
実行結果。行数は100万なので連番の最大は99万9999になるはずだが。
+------------+ | max(id_b)| +------------+ |420906815007| +------------+
振られる数値はユニークにはなるので連番である必要がなければこちらの方が便利。
zipWithIndexであればrepartitionした場合でも連番になっている。
df2 = df.repartition(50) \ .rdd.zipWithIndex() \ .map(new_row) \ .toDF() df2.agg(F.max('id_b')).show()
+---------+ |max(id_b)| +---------+ | 999999| +---------+
改定履歴
Author: Masato Watanabe, Date: 2019-02-07, 記事投稿