I need to compute the element-wise sum of ALL the DenseVectors for each sentence. Each row's cell can hold multiple DenseVectors (the embeddings) for that sentence.
To illustrate, I'll use plain nested lists in each cell. Take the following example:
+----------------------------------+
|abc                               |
+----------------------------------+
|[[1.0, 2.0, 3.0], [1.2, 2.2, 3.2]]|
|[[2.0, 3.0, 4.0], [3.1, 4.1, 5.1]]|
|[[3.0, 4.0, 5.0], [4.3, 5.3, 6.3]]|
|[[4.0, 5.0, 6.0], [5.2, 4.2, 7.2]]|
+----------------------------------+
The output required is:
+----------------+
|abc             |
+----------------+
|[2.2, 4.2, 6.2] |
|[5.1, 7.1, 9.1] |
|[7.3, 9.3, 11.3]|
|[9.2, 9.2, 13.2]|
+----------------+
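To make the required operation concrete, here is a minimal plain-Python sketch (no Spark involved) of the per-row element-wise sum I'm after. The helper name `elementwise_sum` is only illustrative, and it assumes every vector within a row has the same length.

```python
from typing import List


def elementwise_sum(vectors: List[List[float]]) -> List[float]:
    """Sum equal-length vectors position by position (illustrative helper)."""
    # zip(*vectors) groups the i-th component of every vector together,
    # so each group can be summed independently.
    return [sum(components) for components in zip(*vectors)]


# Same data as the example table: one list of vectors per row.
rows = [
    [[1.0, 2.0, 3.0], [1.2, 2.2, 3.2]],
    [[2.0, 3.0, 4.0], [3.1, 4.1, 5.1]],
    [[3.0, 4.0, 5.0], [4.3, 5.3, 6.3]],
    [[4.0, 5.0, 6.0], [5.2, 4.2, 7.2]],
]

# One summed vector per row; rows are never combined with each other.
summed = [elementwise_sum(row) for row in rows]

for vector in summed:
    # Rounded only for display, to hide floating-point noise.
    print([round(x, 1) for x in vector])
```

The sum runs across the vectors inside a single cell, not across rows, so the number of rows in the result equals the number of rows in the input.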
I’ve tried a few approaches along the following lines, but none of them works.
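from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import array, col, sum  # note: shadows the built-in sum
from pyspark.sql.types import ArrayType, FloatType, StructField, StructType

# Assumption: reuse the active session, or create one when run outside a notebook.
spark = SparkSession.builder.getOrCreate()

schema = StructType([
    StructField("abc", ArrayType(ArrayType(FloatType())), True)
])
test_data = [
    Row([[1.0, 2.0, 3.0], [1.2, 2.2, 3.2]]),
    Row([[2.0, 3.0, 4.0], [3.1, 4.1, 5.1]]),
    Row([[3.0, 4.0, 5.0], [4.3, 5.3, 6.3]]),
    Row([[4.0, 5.0, 6.0], [5.2, 4.2, 7.2]])
]
test_df = spark.createDataFrame(data=test_data, schema=schema)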
test_df.groupBy("abc").agg(
    array([
        sum(col("abc")[i].getItem(j))
        for j in range(3)
        for i in range(len(test_df.select('abc').first()[0]))
    ]).alias("sum")
).show(truncate=False)
test_df.groupBy("abc").agg(
    array(*[
        array(*[
            sum(col("abc")[i].getItem(j))
            for j in range(3)
            for i in range(len(test_df.select('abc').first()[0]))
        ])
    ]).alias("sum")
).show(truncate=False)
test_df.groupBy("abc").agg(
    array(*[
        array(*[
            sum(col("abc")[i].getItem(j))
            for i in range(len(test_df.select('abc').first()[0]))
        ])
        for j in range(3)
    ]).alias("sum")
).show(truncate=False)