Given a Polars DataFrame
# Example input: a ``user_id`` column and a boolean ``event`` flag per row.
data = pl.DataFrame(
    {
        "user_id": [1, 1, 1, 1, 1, 2, 2, 2, 2],
        "event": [False, True, True, False, True, True, True, False, False],
    }
)
I wish to calculate a column `event_chain`
that counts a per-user streak of events: an event belongs to a streak when the user also had an event in any of the previous 4 rows. Every time a new event happens while the user already has an active streak, the streak counter is incremented; the counter should then be reset to zero if the user does not have another event within the following 4 rows.
user_id | event | event_chain |
---|---|---|
1 | False | 0 |
1 | True | 0 |
1 | True | 1 |
1 | False | 1 |
1 | True | 2 |
2 | True | 0 |
2 | True | 1 |
2 | False | 0 |
2 | False | 0 |
I have working code as follows to do this, but I think there should be a cleaner way to do it:
# Derive ``event_chain`` from ``data`` through a sequence of helper columns.
# The whole chain is wrapped in parentheses so that each ``.with_columns``
# call can start on its own line.
(
    data.with_columns(
        # Row distance back to the most recent event of the same user
        # (0 on an event row, null before the user's first event).
        rows_since_last_event=pl.int_range(pl.len()).over("user_id")
        - pl.when("event")
        .then(pl.int_range(pl.len()))
        .forward_fill()
        .over("user_id"),
        # Row distance forward to the next event of the same user
        # (0 on an event row, null after the user's last event).
        # Partitioned by "user_id" like every other window here; the
        # previous "athlete_id" key does not exist in ``data``.
        rows_till_next_event=pl.when("event")
        .then(pl.int_range(pl.len()))
        .backward_fill()
        .over("user_id")
        - pl.int_range(pl.len()).over("user_id"),
    )
    .with_columns(
        # 1 when the rolling sum over the user's window, minus the current
        # row's own value, is positive; 0 otherwise.
        # NOTE(review): the window of 4 includes the current row, so after
        # the subtraction only the 3 preceding rows are counted -- confirm
        # whether "previous 4 rows" needs window_size=5.
        chain_event=pl.when(
            pl.col("event")
            .fill_null(0)
            .rolling_sum(window_size=4, min_periods=1)
            .over("user_id")
            - pl.col("event").fill_null(0)
            > 0
        )
        .then(1)
        .otherwise(0)
    )
    .with_columns(
        # 1 on rows that are treated as a chain boundary, 0 elsewhere.
        # Multiple predicates passed to ``pl.when`` are combined with AND.
        # NOTE(review): ``shift()`` is not partitioned by "user_id", so the
        # first row of a user is compared with the last row of the previous
        # user -- confirm this is acceptable.
        chain_event_change=pl.when(
            # Flag switches 0 -> 1 and the last event is more than 3 rows
            # back (or there has been none: null is replaced by 5).
            pl.col("chain_event").eq(1),
            pl.col("chain_event").shift().eq(0),
            pl.col("rows_since_last_event").fill_null(5) > 3,
        )
        .then(1)
        .when(
            # Flag switches 1 -> 0 and the next event is more than 3 rows
            # ahead (or there is none: null is replaced by 5).
            # Uses "chain_event"; the previous "congested_event" column is
            # never created by this pipeline.
            pl.col("chain_event").eq(0),
            pl.col("chain_event").shift().eq(1),
            pl.col("rows_till_next_event").fill_null(5) > 3,
        )
        .then(1)
        .otherwise(0)
    )
    .with_columns(
        # Running count of boundaries per user: rows between two boundaries
        # share the same identifier.
        chain_event_identifier=pl.col("chain_event_change")
        .cum_sum()
        .over("user_id")
    )
    .with_columns(
        # Running total of ``chain_event`` inside each (user, identifier)
        # group, i.e. the counter restarts whenever the identifier changes.
        event_chain=pl.col("chain_event")
        .cum_sum()
        .over("user_id", "chain_event_identifier")
    )
)