In the data below, for each id2, I want to collect a list of the id1s that sit above it in the hierarchy of levels.
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.getOrCreate()
schema = StructType([
    StructField("group_id", StringType(), False),
    StructField("level", IntegerType(), False),
    StructField("id1", IntegerType(), False),
    StructField("id2", IntegerType(), False),
])
# Sample values
levels = [1, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
id1_values = [0, 200001, 677555, 677555, 677555, 677555, 677555, 605026, 605026, 605026, 605026, 605026, 605026, 662867, 662867, 662867, 662867, 662867]
id2_values = [200001, 677555, 605026, 662867, 676423, 659933, 660206, 675767, 681116, 913248,
              910758, 913773, 698738, 910387, 910758, 910387, 910113, 910657]
rows = list(zip(['A'] * len(levels), levels, id1_values, id2_values))
# Create DataFrame
data = spark.createDataFrame(rows, schema)
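For context, each row is a parent-to-child (id1 -> id2) edge, so group A encodes this hierarchy (note that 910758 appears under both level-4 id1s, and the 662867 -> 910387 edge is duplicated):

# level 1: 0      -> 200001
# level 2: 200001 -> 677555
# level 3: 677555 -> 605026, 662867, 676423, 659933, 660206
# level 4: 605026 -> 675767, 681116, 913248, 910758, 913773, 698738
# level 4: 662867 -> 910387, 910758, 910387, 910113, 910657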
I can get close using a window function and collect_list:
window = Window.partitionBy('group_id').orderBy('level').rowsBetween(Window.unboundedPreceding, Window.currentRow)
data.withColumn("list_id1", F.collect_list("id1").over(window)).display()  # .display() is Databricks-only; use .show(truncate=False) elsewhere
The issue is that in some cases several id1s share the same level, and I want collect_list to take this into account. For example, on level 4 there are two distinct id1s, 605026 and 662867. For id2 910387, which belongs to id1 662867 on level 4, I don't want 605026 in the list, but the window above collects every preceding id1 regardless of branch. The collected list should contain exactly one id1 per level: the one on that id2's own ancestor path.
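To make the target concrete, reading the sample data top-down, the list I would expect for id2 910387 is:

# expected list_id1 for id2 = 910387 (one id1 per level, following its own branch)
# [0, 200001, 677555, 662867]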
How can this be achieved using the PySpark API?
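One direction that might work (a sketch, not a definitive answer): treat each row as an id1 -> id2 edge and walk each id2 up the hierarchy with repeated self-joins, prepending the parent found at each hop. The names edges, parents, and top, and the fixed hop count of 3, are my own choices; three hops only suffice because this sample is four levels deep.

# Sketch: iterative self-joins that walk each id2 up to the root.
edges = data.select("group_id", "id1", "id2").dropDuplicates()

# Seed: each id2's direct parent as a one-element ancestor list.
result = edges.select(
    "group_id",
    "id2",
    F.col("id1").alias("top"),         # topmost ancestor found so far
    F.array("id1").alias("list_id1"),  # ancestor path collected so far
)

# Renamed copy of the edges so the self-join below is unambiguous.
parents = edges.select(
    F.col("group_id").alias("p_group"),
    F.col("id2").alias("p_child"),
    F.col("id1").alias("p_parent"),
)

# The sample tree is four levels deep, so three extra hops reach the root;
# with unknown depth you would loop until no row gains a new ancestor.
for _ in range(3):
    result = (
        result.join(
            parents,
            (F.col("group_id") == F.col("p_group"))
            & (F.col("top") == F.col("p_child")),
            "left",
        )
        # Prepend the newly found parent, if any, to the path.
        .withColumn(
            "list_id1",
            F.when(
                F.col("p_parent").isNotNull(),
                F.concat(F.array("p_parent"), F.col("list_id1")),
            ).otherwise(F.col("list_id1")),
        )
        .withColumn("top", F.coalesce("p_parent", "top"))
        .drop("p_group", "p_child", "p_parent")
    )

result.select("group_id", "id2", "list_id1").show(truncate=False)

Since 910758 sits under both level-4 id1s, this keeps one row (and one path) per parent, which seems like the semantics you'd want; an id2 with a single parent comes out with exactly one id1 per level, e.g. [0, 200001, 677555, 662867] for 910387.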