I need to solve a problem where a company wants to offer k different users free use of its application for two months (a kind of coupon). The goal is to identify users who are likely to churn (leave the service) and, among them, select the k users who are most worth retaining. A user worth retaining is one who brings the company a lot of value by listening to music frequently, but whose listening frequency has recently dropped significantly.
I want to compute each user's average monthly play count over their last 6 months and compare it with their average over the last 3 months. If a user's 3-month average is less than 50% of their 6-month average, they should be flagged as at-risk.
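A minimal plain-Python sketch of the 50% rule, so it can be checked without Spark. It assumes each user's monthly play counts arrive as a list ordered newest month first (the order that a `sorted(..., reverse=True)[:6]` step produces) and skips users with fewer than 6 months of history. That cutoff and the helper names `retention_stats` and `is_at_risk` are my own assumptions, not part of the question.

```python
def retention_stats(monthly_counts):
    """Return (avg6, avg3, ratio) for one user's monthly play counts.

    monthly_counts must be ordered newest month first.
    Assumption: users with fewer than 6 months of data are not scored (None).
    """
    if len(monthly_counts) < 6:
        return None
    last6 = monthly_counts[:6]
    avg6 = sum(last6) / 6          # average over the latest 6 months
    avg3 = sum(last6[:3]) / 3      # average over the latest 3 months
    ratio = avg3 / avg6 if avg6 > 0 else 0.0
    return avg6, avg3, ratio


def is_at_risk(monthly_counts, threshold=0.5):
    """True when the 3-month average is below `threshold` times the 6-month average."""
    stats = retention_stats(monthly_counts)
    return stats is not None and stats[2] < threshold
```

For example, `is_at_risk([10, 20, 30, 400, 500, 600])` is `True`, because the 3-month average (20) is well below half of the 6-month average (260). Because the function works on a plain list, it could be passed to `mapValues` on an RDD of `(userid, [counts...])` pairs.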
I'm not allowed to use DataFrames. Only PySpark with RDDs is permitted.
Here is my current code:
```python
from datetime import datetime

from pyspark import SparkContext

# users_rdd and tracks_rdd are assumed to be created earlier (not shown here).


def k_users_to_retain(k, timestamp):
    # Step 1: keep users who registered before the timestamp and collect their IDs.
    # A set makes the membership test in Step 2 O(1) instead of O(n) per track.
    filtered_users = users_rdd.filter(
        lambda user: user.registered is not None and user.registered < timestamp
    )
    eligible_user_ids = set(filtered_users.map(lambda user: user.userid).collect())

    # Step 2: keep tracks of eligible users and count plays per (user, year, month).
    filtered_tracks = (
        tracks_rdd.filter(lambda track: track.userid in eligible_user_ids)
        .map(lambda track: ((track.userid, track.year, track.month), 1))
        .reduceByKey(lambda a, b: a + b)
    )

    # Step 3: group by user, sort by (year, month) descending, keep the latest 6.
    # Assumes year and month are integers so the sort is chronological.
    grouped_by_user = (
        filtered_tracks.map(lambda x: (x[0][0], (x[0][1], x[0][2], x[1])))
        .groupByKey()
        .mapValues(lambda records: sorted(records, key=lambda r: (r[0], r[1]), reverse=True)[:6])
    )

    # Flatten to (userid, monthly_count) pairs, newest month first.
    top_6_per_user = grouped_by_user.flatMap(
        lambda x: [(x[0], record[2]) for record in x[1]]
    )

    # Collect the results to the driver.
    return top_6_per_user.collect()


# Define the inputs
k = 5
timestamp = datetime(2009, 4, 8)

# Get each eligible user's latest 6 monthly play counts (k is not used yet)
result = k_users_to_retain(k, timestamp)

# Display the results
for record in result:
    print(record)
```
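One possible way to finish the missing part using only RDD operations is sketched below. It assumes `grouped_by_user` is the Step 3 RDD of `(userid, [(year, month, count), ...])` sorted newest first. It flags users whose 3-month average is below 50% of their 6-month average, then treats the 6-month average as the measure of a user's value (an assumption, since the question does not define value precisely) and picks the top k with `takeOrdered`. The helper names `user_scores` and `select_users_to_retain` are hypothetical. The functions only call RDD methods on their argument, so the snippet itself does not import pyspark.

```python
# Hypothetical helpers; grouped_by_user is assumed to be the Step 3 RDD:
# (userid, [(year, month, count), ...]) with records sorted newest first.


def user_scores(records):
    """Map one user's records to (avg6, avg3), or None if fewer than 6 months."""
    counts = [count for _year, _month, count in records]
    if len(counts) < 6:
        return None  # assumption: too little history to judge a trend
    return sum(counts[:6]) / 6, sum(counts[:3]) / 3


def select_users_to_retain(grouped_by_user, k, threshold=0.5):
    """Keep at-risk users and return the k with the highest 6-month average.

    Uses only RDD operations (mapValues, filter, takeOrdered).
    Returns a Python list of (userid, (avg6, avg3)) on the driver.
    """
    return (
        grouped_by_user
        .mapValues(user_scores)
        # at-risk: 3-month average below threshold * 6-month average
        .filter(lambda kv: kv[1] is not None and kv[1][1] < threshold * kv[1][0])
        # highest 6-month average first (assumed proxy for "value")
        .takeOrdered(k, key=lambda kv: -kv[1][0])
    )
```

Inside `k_users_to_retain`, the flatten-and-collect lines could be replaced with `return select_users_to_retain(grouped_by_user, k)`. Since `takeOrdered` returns only k records to the driver, this also avoids collecting six rows per user.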
Here is part of the output:
```
('user_000001', 62)
('user_000001', 822)
('user_000001', 700)
('user_000001', 671)
('user_000001', 680)
('user_000001', 760)
('user_000002', 486)
('user_000002', 645)
('user_000002', 673)
('user_000002', 791)
('user_000002', 608)
('user_000002', 953)
('user_000003', 351)
('user_000003', 50)
('user_000003', 140)
('user_000003', 401)
('user_000003', 88)
('user_000003', 183)
('user_000004', 22)
('user_000004', 504)
('user_000004', 35)
('user_000004', 39)
('user_000004', 539)
('user_000004', 693)
```
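As a rough sanity check of the 50% rule on this sample, assuming each user's six values are listed newest month first (which the `reverse=True` sort should produce), let $S_3$ and $S_6$ be the sums of the 3 and 6 most recent monthly counts. The ratio of the two averages then simplifies to $2S_3/S_6$:

```latex
r_u = \frac{\bar{x}_{3}}{\bar{x}_{6}} = \frac{S_3/3}{S_6/6} = \frac{2\,S_3}{S_6},
\qquad \text{at-risk} \iff r_u < 0.5 \iff S_3 < \tfrac{1}{4}\,S_6

\begin{aligned}
\texttt{user\_000001}&: S_3 = 1584,\; S_6 = 3695,\; r_u = 3168/3695 \approx 0.857\\
\texttt{user\_000002}&: S_3 = 1804,\; S_6 = 4156,\; r_u = 3608/4156 \approx 0.868\\
\texttt{user\_000003}&: S_3 = 541,\;  S_6 = 1213,\; r_u = 1082/1213 \approx 0.892\\
\texttt{user\_000004}&: S_3 = 561,\;  S_6 = 1832,\; r_u = 1122/1832 \approx 0.612
\end{aligned}
```

Under that ordering assumption, none of these four users would be flagged; `user_000004` comes closest.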