Hi guys, all good?
I need some help computing a cumulative sum of a sumproduct. I have a DataFrame like this:
dfz = spark.createDataFrame( [
(202401, 20240101, 20, 0, 0.1, 20.0),
(202401, 20240102, 50, 20, 0.2, 50.0),
(202401, 20240103, 20, 50, 0.2, 20.0),
(202401, 20240104, None, 20, 0.2, None),
(202401, 20240105, None, None, 0.3, None),
(202401, 20240106, None, None, 0.1, None)],
["year_month", "date", "amount", "amount_lag", "perc_prev", "amount_prev"] )
+----------+--------+------+----------+---------+-----------+
|year_month|    date|amount|amount_lag|perc_prev|amount_prev|
+----------+--------+------+----------+---------+-----------+
|    202401|20240101|    20|         0|      0.1|       20.0|
|    202401|20240102|    50|        20|      0.2|       50.0|
|    202401|20240103|    20|        50|      0.2|       20.0|
|    202401|20240104|  NULL|        20|      0.2|       NULL|
|    202401|20240105|  NULL|      NULL|      0.3|       NULL|
|    202401|20240106|  NULL|      NULL|      0.1|       NULL|
+----------+--------+------+----------+---------+-----------+
So I need to calculate amount_prev, which is (amount_lag * perc_prev) + amount_lag. The problem is that amount_prev should only be calculated when amount is null (otherwise it just keeps the value of amount), and when amount_lag is null we use the lag of amount_prev instead. The final DataFrame should look like this:
dfz = spark.createDataFrame( [
(202401, 20240101, 20, 0, 0.1, 20.0),
(202401, 20240102, 50, 20, 0.2, 50.0),
(202401, 20240103, 20, 50, 0.2, 20.0),
(202401, 20240104, None, 20, 0.2, 24.0),
(202401, 20240105, None, 24, 0.3, 31.2),
(202401, 20240106, None, 31, 0.1, 34.32)],
["year_month", "date", "amount", "amount_lag", "perc_prev", "amount_prev"] )
+----------+--------+------+----------+---------+-----------+
|year_month|    date|amount|amount_lag|perc_prev|amount_prev|
+----------+--------+------+----------+---------+-----------+
|    202401|20240101|    20|         0|      0.1|       20.0|
|    202401|20240102|    50|        20|      0.2|       50.0|
|    202401|20240103|    20|        50|      0.2|       20.0|
|    202401|20240104|  NULL|        20|      0.2|       24.0|
|    202401|20240105|  NULL|        24|      0.3|       31.2|
|    202401|20240106|  NULL|        31|      0.1|      34.32|
+----------+--------+------+----------+---------+-----------+
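Written out, this is how I read the rule from the expected output (my interpretation, not a formula from any library): since (amount_lag * perc_prev) + amount_lag = amount_lag * (1 + perc_prev), each null row grows the previous amount_prev, so over a run of nulls it behaves like a cumulative product rather than a sum. Note that 34.32 only works out if the unrounded 31.2 is carried forward; with the 31 shown in amount_lag it would be 34.1.

```latex
\text{amount\_prev}_t =
\begin{cases}
\text{amount}_t & \text{if } \text{amount}_t \text{ is not null} \\
\text{amount\_prev}_{t-1}\,(1 + \text{perc\_prev}_t) & \text{otherwise}
\end{cases}

\begin{aligned}
\text{amount\_prev}_{20240104} &= 20 \times (1 + 0.2) = 24.0 \\
\text{amount\_prev}_{20240105} &= 24.0 \times (1 + 0.3) = 20 \times 1.2 \times 1.3 = 31.2 \\
\text{amount\_prev}_{20240106} &= 31.2 \times (1 + 0.1) = 20 \times 1.2 \times 1.3 \times 1.1 = 34.32
\end{aligned}
```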
I tried this:
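from pyspark.sql import Window
from pyspark.sql.functions import col, when

# cumulative window per year_month/date (defined, but not used below)
w1 = (
    Window.partitionBy("year_month", "date")
    .orderBy("date")
    .rangeBetween(Window.unboundedPreceding, 0)
)

test = (
    dfz
    .withColumn(
        "amount_prev",
        # keep amount when present, otherwise (amount_lag * perc_prev) + amount_lag
        when(col("amount").isNotNull(), col("amount"))
        .otherwise((col("amount_lag") * col("perc_prev")) + col("amount_lag")),
    )
)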
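I suspect part of the problem is that w1 partitions by both year_month and date: since date is unique per row, every partition would hold a single row, so there is nothing to accumulate (and w1 isn't actually applied in the withColumn). Here is a quick plain-Python check of the partition sizes. It only mimics how Window.partitionBy would group these six rows and does not run Spark; partition_sizes is a hypothetical helper.

```python
from collections import defaultdict

# (year_month, date) pairs from the example DataFrame
rows = [
    (202401, 20240101), (202401, 20240102), (202401, 20240103),
    (202401, 20240104), (202401, 20240105), (202401, 20240106),
]


def partition_sizes(rows, key):
    """Hypothetical helper: count rows per partition key, like partitionBy groups them."""
    sizes = defaultdict(int)
    for row in rows:
        sizes[key(row)] += 1
    return dict(sizes)


# partitionBy("year_month", "date"): every partition has exactly one row
by_month_and_date = partition_sizes(rows, key=lambda r: (r[0], r[1]))
# partitionBy("year_month"): all six rows land in the same partition
by_month = partition_sizes(rows, key=lambda r: r[0])

print(sorted(set(by_month_and_date.values())))  # [1]
print(by_month)  # {202401: 6}
```

Even partitioned by year_month only, a running sum still wouldn't match, because each null row depends on the value computed for the row before it.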
But it didn't work. Can anyone help me?
I also tried summing over a window, but it didn't accumulate within each year_month. What I'm after is a way to compute a DataFrame like the second example.
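One idea I had (my own assumption, not verified in Spark): because (amount_lag * perc_prev) + amount_lag equals amount_lag * (1 + perc_prev), over a run of null amounts the value is the last non-null amount times a running product of (1 + perc_prev). A product can be turned into a sum as exp(sum(log1p(perc_prev))), which is the shape a window sum can handle, and a running count of non-null amounts gives each run its own group id. The plain-Python sketch below does it that way for a single year_month already sorted by date, assuming perc_prev > -1 so the log is defined. In PySpark I imagine the pieces would be count, sum, log1p and exp from pyspark.sql.functions over a window partitioned by year_month (plus the group id) and ordered by date, but I haven't got that version working.

```python
import math
from itertools import accumulate


def amount_prev_via_cumprod(amounts, percs):
    """Cumulative-product form of the rule, for one year_month sorted by date.

    A non-null amount starts a new group; inside a group, each null row is the
    group's amount times the running product of (1 + perc_prev) over the null
    rows so far, computed as exp(sum(log1p(perc_prev))).
    """
    # Group id = running count of non-null amounts (mirrors a running count("amount"))
    group_ids = list(accumulate(1 if a is not None else 0 for a in amounts))

    result = []
    anchor = None        # the group's starting (non-null) amount
    log_growth = 0.0     # running sum of log1p(perc_prev) within the group
    current_group = None
    for amount, perc, group in zip(amounts, percs, group_ids):
        if group != current_group:
            # New group: reset the running log-sum
            current_group, log_growth = group, 0.0
        if amount is not None:
            anchor = float(amount)
            result.append(anchor)
        elif anchor is None:
            result.append(None)  # leading nulls: nothing to grow from yet
        else:
            log_growth += math.log1p(perc)
            result.append(anchor * math.exp(log_growth))
    return result


amounts = [20, 50, 20, None, None, None]
percs = [0.1, 0.2, 0.2, 0.2, 0.3, 0.1]
print([round(v, 2) for v in amount_prev_via_cumprod(amounts, percs)])
# [20.0, 50.0, 20.0, 24.0, 31.2, 34.32]
```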