I’m looping over groups in a PySpark DataFrame and, for each group, performing one filter, several joins (the number depends on the depth of the group), and one union. The individual groups are quite small: in my real-world use case each group has between 3 and 20 rows. I have around 1500 groups to loop through, and it takes a very long time.
I run this on Databricks Runtime 14.3 with a 64 GB driver and 8 workers.
I’m interested in how I should think about optimization here. There are a lot of recommendations online (broadcasting, caching, etc.), but I find it hard to know when to use which.
How can the code snippet below be optimized? How would a Spark developer think about it?
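Concretely, these are the kinds of hints I keep reading about, applied to the test_df from the snippet below (a purely illustrative sketch; per_group and joined are just example names, and I have not measured any of this):
from pyspark.sql import functions as F
# Cache the source DataFrame once, since the loop filters it ~1500 times
test_df = test_df.cache()
# Hint that the tiny per-group side of a join should be broadcast to the executors
per_group = test_df.filter(F.col("group_id") == "A")  # "A" is just an example group
joined = test_df.alias("child").join(
    F.broadcast(per_group).alias("parent"),
    F.col("child.parent") == F.col("parent.node"),
    "left",
)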
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.appName("OPT").getOrCreate()
data = [
("A", 1, 0, 2121),
("A", 2, 2121, 5567),
("A", 3, 5567, 5566),
("A", 3, 5567, 5568),
("A", 3, 5567, 5569),
("A", 3, 5567, 5570),
("B", 1, 0, 3331),
("B", 2, 3331, 5515),
]
columns = ["group_id", "level", "parent", "node"]
test_df = spark.createDataFrame(data, columns)
test_df = test_df.withColumn("path", F.array("parent"))
# Collect one (group_id, depth) row per group to drive the driver-side loop
list_to_iterate = test_df.groupBy("group_id").agg(F.max("level").alias("depth")).collect()
# Empty dataframe to store result from loop
new_result_df = spark.createDataFrame([], schema=test_df.schema)
for group in list_to_iterate:
    current_level = group['depth']
    # Filter out the rows belonging to this group
    tmp = test_df.filter(F.col('group_id') == group['group_id'])
    original_group = tmp
    while current_level > 1:
        # Repeated self-join: replace each row's parent with the next ancestor up
        joined_df = tmp.alias("child").join(
            original_group.alias("parent"),
            F.col("child.parent") == F.col("parent.node"),
            "left"
        ).select(
            F.col("child.group_id"),
            F.col("child.level"),
            F.col("parent.parent").alias("parent"),
            F.col("child.node"),
            # Append the newly found ancestor to the path
            F.expr("CASE WHEN parent.parent IS NOT NULL THEN array_union(child.path, array(parent.parent)) ELSE child.path END").alias("path")
        )
        tmp = joined_df
        current_level -= 1
    # Union this group's result onto the accumulator
    # (union tmp rather than joined_df, so a depth-1 group does not reuse a stale joined_df)
    new_result_df = new_result_df.union(tmp)
new_result_df.show(truncate=False)
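After writing this out, my suspicion is that the cost comes from the ~1500 driver-side iterations (and the ever-growing union/join plan) rather than from the data itself. Would something in the following direction be how a Spark developer would think about it: run the level-wise self-join once over all groups, so the number of joins is bounded by the maximum depth (~20) instead of the number of groups (~1500)? This is only a rough, untested sketch; max_depth, base and single_pass_result are names I made up, and I have not verified that it produces identical results:
max_depth = test_df.agg(F.max("level")).collect()[0][0]
base = test_df.select("group_id", "parent", "node")
tmp = test_df
for _ in range(max_depth - 1):
    # One join per level, across ALL groups at once; rows whose chain is already
    # exhausted simply keep a NULL parent and an unchanged path
    tmp = tmp.alias("child").join(
        base.alias("parent"),
        (F.col("child.group_id") == F.col("parent.group_id"))
        & (F.col("child.parent") == F.col("parent.node")),
        "left"
    ).select(
        F.col("child.group_id"),
        F.col("child.level"),
        F.col("parent.parent").alias("parent"),
        F.col("child.node"),
        F.expr("CASE WHEN parent.parent IS NOT NULL THEN array_union(child.path, array(parent.parent)) ELSE child.path END").alias("path")
    )
single_pass_result = tmp
single_pass_result.show(truncate=False)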