Say I have a dataframe containing magazine subscription information:
subscription_id  user_id  created_at  expiration_date
12384            1        2018-08-10  2018-12-10
83294            1        2018-06-03  2018-10-03
98234            1        2018-04-08  2018-08-08
24903            2        2018-05-08  2018-07-08
32843            2        2018-03-25  2018-05-25
09283            2        2018-04-07  2018-06-07
Now I want to add a column that states how many of the user's previous subscriptions had expired before the current subscription began. In other words, for each row, how many expiration dates associated with that user fall before this subscription's start date. Here is the full desired output:
subscription_id  user_id  created_at  expiration_date  previous_expired
12384            1        2018-08-10  2018-12-10       1
83294            1        2018-06-03  2018-10-03       0
98234            1        2018-04-08  2018-08-08       0
24903            2        2018-05-08  2018-07-08       2
32843            2        2018-03-25  2018-05-03       1
09283            2        2018-01-25  2018-02-25       0
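To pin down the definition, here is a plain-Python sketch (no Spark) that computes previous_expired by brute force for the rows in the desired-output table. It compares dates as ISO-8601 strings, which sort in the same order as the dates themselves.

```python
from collections import defaultdict

# Rows from the desired-output table: (subscription_id, user_id, created_at, expiration_date).
# ISO-8601 date strings compare correctly as plain strings.
rows = [
    ("12384", 1, "2018-08-10", "2018-12-10"),
    ("83294", 1, "2018-06-03", "2018-10-03"),
    ("98234", 1, "2018-04-08", "2018-08-08"),
    ("24903", 2, "2018-05-08", "2018-07-08"),
    ("32843", 2, "2018-03-25", "2018-05-03"),
    ("09283", 2, "2018-01-25", "2018-02-25"),
]

# Collect every expiration date per user.
expirations_by_user = defaultdict(list)
for _, user_id, _, expiration_date in rows:
    expirations_by_user[user_id].append(expiration_date)

# For each subscription, count the same user's expirations that fall
# strictly before this subscription's start date.
previous_expired = [
    sum(1 for exp in expirations_by_user[user_id] if exp < created_at)
    for _, user_id, created_at, _ in rows
]

print(previous_expired)  # [1, 0, 0, 2, 1, 0]
```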
EDIT: I tried a variety of lag/lead/etc. approaches in PySpark, and I'm now thinking this is a SQL problem. One attempt:
from pyspark.sql import Window
import pyspark.sql.functions as func

df = df.withColumn('shiftlag', func.lag(df.expiration_date).over(Window.partitionBy('user_id').orderBy('created_at')))
EDIT 2: Never mind, this doesn't work.
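A likely reason the lag attempt fails: lag(expiration_date) over a window ordered by created_at only exposes the expiration date of the single preceding subscription, while the desired count needs every earlier expiration, and a user's subscriptions can overlap. The plain-Python sketch below (no Spark; the helpers lag_based_count and full_count are hypothetical names) mimics that window on the rows from the desired-output table and compares it with the full count.

```python
from itertools import groupby

# Rows from the desired-output table: (subscription_id, user_id, created_at, expiration_date).
rows = [
    ("12384", 1, "2018-08-10", "2018-12-10"),
    ("83294", 1, "2018-06-03", "2018-10-03"),
    ("98234", 1, "2018-04-08", "2018-08-08"),
    ("24903", 2, "2018-05-08", "2018-07-08"),
    ("32843", 2, "2018-03-25", "2018-05-03"),
    ("09283", 2, "2018-01-25", "2018-02-25"),
]


def lag_based_count(rows):
    """Mimic lag(expiration_date) over (partition by user_id order by created_at):
    each row sees only the previous subscription's expiration date."""
    result = {}
    ordered = sorted(rows, key=lambda r: (r[1], r[2]))
    for _, group in groupby(ordered, key=lambda r: r[1]):
        prev_expiration = None  # lag() returns null for the first row of each partition
        for sub_id, _, created_at, expiration_date in group:
            result[sub_id] = int(prev_expiration is not None and prev_expiration < created_at)
            prev_expiration = expiration_date
    return result


def full_count(rows):
    """The desired column: every same-user expiration strictly before created_at."""
    return {
        sub_id: sum(1 for _, u, _, exp in rows if u == user_id and exp < created_at)
        for sub_id, user_id, created_at, _ in rows
    }


lag = lag_based_count(rows)
full = full_count(rows)
print([sub_id for sub_id in full if full[sub_id] != lag[sub_id]])  # ['12384', '24903']
```

The two disagree for 12384 and 24903. For 12384, the immediately preceding subscription (83294) expires on 2018-10-03, after 12384 starts, so lag sees nothing to count, even though 98234 expired on 2018-08-08. For 24903, lag sees only 2018-05-03 and misses 2018-02-25.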
I think I've exhausted the lag/lead/shift methods and found they don't work. I'm now thinking it would be best to do this in Spark SQL, perhaps with a CASE WHEN to produce the new column, combined with a HAVING, grouped by ID?
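The CASE WHEN idea can work if it is paired with a self-join on user_id instead of a HAVING (HAVING filters groups; it does not produce a per-row count): join each subscription to all of the same user's subscriptions, then sum a CASE WHEN flag per subscription. The sketch below uses Python's built-in sqlite3 module so it runs without Spark, on the rows from the desired-output table; the table name subs is a placeholder, and the query has only been checked against SQLite here.

```python
import sqlite3

# Rows from the desired-output table: (subscription_id, user_id, created_at, expiration_date).
# subscription_id is stored as TEXT so the leading zero in "09283" survives.
rows = [
    ("12384", 1, "2018-08-10", "2018-12-10"),
    ("83294", 1, "2018-06-03", "2018-10-03"),
    ("98234", 1, "2018-04-08", "2018-08-08"),
    ("24903", 2, "2018-05-08", "2018-07-08"),
    ("32843", 2, "2018-03-25", "2018-05-03"),
    ("09283", 2, "2018-01-25", "2018-02-25"),
]

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE subs (subscription_id TEXT, user_id INTEGER, "
    "created_at TEXT, expiration_date TEXT)"
)
conn.executemany("INSERT INTO subs VALUES (?, ?, ?, ?)", rows)

# Self-join: pair every subscription (a) with every subscription of the same user (b),
# then add up a CASE WHEN flag for the b rows that expired strictly before a started.
# Each row is also paired with itself, but it expires after it starts, so that pair adds 0.
query = """
SELECT a.subscription_id,
       a.user_id,
       a.created_at,
       a.expiration_date,
       SUM(CASE WHEN b.expiration_date < a.created_at THEN 1 ELSE 0 END) AS previous_expired
FROM subs AS a
JOIN subs AS b
  ON a.user_id = b.user_id
GROUP BY a.subscription_id, a.user_id, a.created_at, a.expiration_date
ORDER BY a.user_id, a.created_at DESC
"""
result = conn.execute(query).fetchall()
conn.close()

for row in result:
    print(row)
```

In PySpark, the same query should be usable by registering the dataframe with df.createOrReplaceTempView('subs') and passing the query to spark.sql(...), assuming a SparkSession named spark.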
Figured it out using PySpark:
I first built a separate dataframe holding an array of all expiration dates for each user:
from pyspark.sql.functions import collect_set

joined_array = df.groupBy('user_id').agg(collect_set('expiration_date'))
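One caveat with collect_set: it removes duplicate values within each group, so if a user ever had two subscriptions expiring on the same day, that date would be counted only once. collect_list (also in pyspark.sql.functions) keeps duplicates. The plain-Python sketch below imitates the two with a list and a set; the extra duplicate row is hypothetical and only there to show the difference.

```python
from collections import defaultdict

# (user_id, expiration_date) pairs for user 2 from the desired-output table,
# plus one HYPOTHETICAL extra subscription expiring on the same day as another.
pairs = [
    (2, "2018-07-08"),
    (2, "2018-05-03"),
    (2, "2018-02-25"),
    (2, "2018-05-03"),  # hypothetical duplicate expiration date
]

as_list = defaultdict(list)  # behaves like collect_list: keeps duplicates
as_set = defaultdict(set)    # behaves like collect_set: drops duplicates
for user_id, expiration_date in pairs:
    as_list[user_id].append(expiration_date)
    as_set[user_id].add(expiration_date)

created_at = "2018-05-08"
count_list = sum(1 for d in as_list[2] if d < created_at)
count_set = sum(1 for d in as_set[2] if d < created_at)
print(count_list, count_set)  # 3 2
```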
Then joined that array back to the original dataframe:
joined_array = joined_array.toDF('user_idDROP', 'expiration_date_array')
df = df.join(joined_array, df.user_id == joined_array.user_idDROP, how='left').drop('user_idDROP')
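After this join, every row carries the full array for its user, including its own expiration date. That is harmless for the count as long as every subscription expires after it was created (true for every row in the tables above), because its own date then never passes the "expired before creation" test. A plain-Python sketch of the two steps (group, then attach back by user_id) on the desired-output rows:

```python
from collections import defaultdict

# Rows from the desired-output table: (subscription_id, user_id, created_at, expiration_date).
rows = [
    ("12384", 1, "2018-08-10", "2018-12-10"),
    ("83294", 1, "2018-06-03", "2018-10-03"),
    ("98234", 1, "2018-04-08", "2018-08-08"),
    ("24903", 2, "2018-05-08", "2018-07-08"),
    ("32843", 2, "2018-03-25", "2018-05-03"),
    ("09283", 2, "2018-01-25", "2018-02-25"),
]

# Grouping step: one array of expiration dates per user.
arrays = defaultdict(list)
for _, user_id, _, expiration_date in rows:
    arrays[user_id].append(expiration_date)

# Join-back step: append each user's array to every one of that user's rows.
joined = [row + (arrays[row[1]],) for row in rows]

for sub_id, user_id, created_at, expiration_date, expiration_date_array in joined:
    # The row's own expiration date is in its array but can never be counted,
    # because it is not earlier than the row's creation date.
    assert expiration_date in expiration_date_array
    assert not expiration_date < created_at

print(joined[3])
```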
Then I created a function that iterates through the array and adds 1 to the count for every expiration date earlier than the subscription's created date:
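from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def check_expiration_count(created_at, expiration_array):
    # A null or empty array (e.g. no match in the left join) counts as zero
    if not expiration_array:
        return 0
    count = 0
    for i in expiration_array:
        # Count each expiration date strictly earlier than this subscription's start
        if created_at > i:
            count += 1
    return count

check_expiration_count = udf(check_expiration_count, IntegerType())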
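Because the function is plain Python until it is wrapped with udf, its logic can be checked locally without a Spark session. The sketch below restates it with sum() over a generator (same behavior, including treating a null or empty array as 0) and calls it on user 2's expiration dates from the desired-output table. It assumes the date columns are DateType, which PySpark passes to Python UDFs as datetime.date objects; if they are ISO-formatted strings instead, the > comparison is lexical but orders them the same way.

```python
from datetime import date
from typing import Optional, Sequence


def check_expiration_count(created_at: date, expiration_array: Optional[Sequence[date]]) -> int:
    """Count expiration dates strictly earlier than created_at;
    a null or empty array counts as zero."""
    if not expiration_array:
        return 0
    return sum(1 for expiration in expiration_array if created_at > expiration)


# User 2's expiration dates from the desired-output table.
user_2_expirations = [date(2018, 7, 8), date(2018, 5, 3), date(2018, 2, 25)]
print(check_expiration_count(date(2018, 5, 8), user_2_expirations))  # 2
print(check_expiration_count(date(2018, 5, 8), None))  # 0
```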
Then applied that function to create a new column with the correct count:
df = df.withColumn('count_of_subs_ending_before_creation', check_expiration_count(df.created_at, df.expiration_date_array))
Voilà. Done. Thanks everyone (nobody helped, but thanks anyway). Hope someone finds this useful in 2022.