I am trying to create row numbers for my data, which is distributed across multiple partitions and possibly across different nodes of a cluster.
Below is the code I used to split the data into multiple partitions based on the `Data` column and then build a global row number:
from pyspark.sql import Window, Row
from pyspark.sql.functions import col, row_number, sum, spark_partition_id, monotonically_increasing_id, min, max
from pyspark.sql.functions import broadcast
empDF = spark.read.csv("dbfs:/FileStore/large_data.csv", header=True)
# Range-partition into 4 partitions on the Data column and tag each row with its partition id
empDFPartId = empDF.repartitionByRange(4, col("Data")).withColumn("partition_id", spark_partition_id()).cache()
# Per-partition offset: running total of the row counts of all preceding partitions
partitionsDF = (
    empDFPartId
    .select("partition_id")
    .groupBy("partition_id")
    .count()
    .withColumn(
        "count",
        sum("count").over(Window.orderBy("partition_id")) - col("count").cast("int")
    )
)
#partitionsDF.show()
# Global row number = partition offset + row number within the partition
empDFRowNumber = (
    empDFPartId
    .join(
        partitionsDF,
        ["partition_id"]
    )
    .withColumn(
        "row_number_within_partition",
        row_number().over(Window.partitionBy("partition_id").orderBy("Data"))
    )
    .withColumn(
        "row_number",
        col("count") + col("row_number_within_partition")
    )
)
(
    empDFRowNumber
    .write
    .mode("overwrite")
    .option("header", "true")
    .csv("dbfs:/FileStore/partition_1")
)
This code gave me a global row number, but there were gaps in the sequence. Below are the minimum and maximum row numbers and the row count for each partition:
| partition_id | min | max | count |
|---|---|---|---|
| 0 | 1 | 999999 | 4291703 |
| 1 | 4291704 | 8584071 | 4292368 |
| 2 | 10000000 | 9999999 | 4292208 |
| 3 | 12876280 | 17169548 | 4293269 |
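I produced that summary with an aggregation roughly like the one below (a sketch, not the exact query I ran; the alias names are illustrative and it reuses the empDFRowNumber DataFrame built above):
from pyspark.sql.functions import min, max, count
# Sketch: per-partition summary of the generated row_number column
rowNumberSummary = (
    empDFRowNumber
    .groupBy("partition_id")
    .agg(
        min("row_number").alias("min"),
        max("row_number").alias("max"),
        count("*").alias("count")
    )
    .orderBy("partition_id")
)
rowNumberSummary.show()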
Does anyone know if there is any other way to do this?