
Spark SQL Partition By, Window, Order By, Count

Say I have a dataframe containing magazine subscription information:

subscription_id    user_id       created_at       expiration_date
 12384               1           2018-08-10        2018-12-10
 83294               1           2018-06-03        2018-10-03
 98234               1           2018-04-08        2018-08-08
 24903               2           2018-05-08        2018-07-08
 32843               2           2018-03-25        2018-05-03
 09283               2           2018-01-25        2018-02-25

Now I want to add a column that states how many previous subscriptions a user had that expired before this current subscription began. In other words, how many expiration dates associated with a given user were before this subscription’s start date. Here is the full desired output:

subscription_id    user_id       created_at       expiration_date   previous_expired
 12384               1           2018-08-10        2018-12-10          1
 83294               1           2018-06-03        2018-10-03          0
 98234               1           2018-04-08        2018-08-08          0
 24903               2           2018-05-08        2018-07-08          2
 32843               2           2018-03-25        2018-05-03          1
 09283               2           2018-01-25        2018-02-25          0

Attempts:

EDIT: I tried a variety of lag/lead/etc. approaches in Python, and I'm now thinking this is a SQL problem:

df = df.withColumn('shiftlag', func.lag(df.expiration_date).over(Window.partitionBy('user_id').orderBy('created_at')))

EDIT 2: Never mind, this doesn't work: lag only pulls a value from a single neighboring row, while I need a count over a condition against all of a user's rows.

I think I've exhausted the lag/lead/shift approach and found it doesn't work. I'm now thinking it would be best to do this in Spark SQL, perhaps producing the new column with a CASE WHEN or a count, grouped by user ID?
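For reference, here's roughly the kind of query I have in mind, untested (a sketch, assuming the dataframe has been registered as a temp view; the view name subs is made up for illustration):

df.createOrReplaceTempView('subs')

# For each subscription, count the same user's other subscriptions whose
# expiration_date falls before this subscription's created_at
result = spark.sql("""
    SELECT s1.subscription_id, s1.user_id, s1.created_at, s1.expiration_date,
           COUNT(s2.subscription_id) AS previous_expired
    FROM subs s1
    LEFT JOIN subs s2
      ON s2.user_id = s1.user_id
     AND s2.expiration_date < s1.created_at
    GROUP BY s1.subscription_id, s1.user_id, s1.created_at, s1.expiration_date
""")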


Answer

Figured it out using PySpark:

I first created another column with an array of all expiration dates for each user:

from pyspark.sql.functions import collect_set

joined_array = df.groupBy('user_id').agg(collect_set('expiration_date'))

Then joined that array back to the original dataframe:

joined_array = joined_array.toDF('user_idDROP', 'expiration_date_array')
df = df.join(joined_array, df.user_id == joined_array.user_idDROP, how='left').drop('user_idDROP')
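As a side note, the rename-and-drop step can be avoided by aliasing the aggregated column and joining on user_id directly; this variation should be equivalent:

from pyspark.sql.functions import collect_set

# Same result without the toDF rename: alias the aggregate, then join on user_id
expirations = df.groupBy('user_id').agg(
    collect_set('expiration_date').alias('expiration_date_array'))
df = df.join(expirations, on='user_id', how='left')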

Then I created a function that iterates through the array and adds 1 to the count whenever the created date is greater than an expiration date:

def check_expiration_count(created_at, expiration_array):
    # Count how many of the user's expiration dates fall before this row's created_at
    if not expiration_array:
        return 0
    count = 0
    for expiration_date in expiration_array:
        if created_at > expiration_date:
            count += 1
    return count

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

check_expiration_count = udf(check_expiration_count, IntegerType())

Then applied that function to create a new column with the correct count:

df = df.withColumn('count_of_subs_ending_before_creation', check_expiration_count(df.created_at, df.expiration_date_array))
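If you're on Spark 2.4 or later, the same count can also be computed without a Python UDF by filtering the array with a built-in higher-order function, which keeps the work inside the JVM; an equivalent sketch:

from pyspark.sql import functions as F

# size(filter(...)) counts the array elements matching the predicate; the
# lambda's x ranges over expiration_date_array, compared to the row's created_at
df = df.withColumn(
    'count_of_subs_ending_before_creation',
    F.expr("size(filter(expiration_date_array, x -> x < created_at))"))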

Voilà. Done. Thanks everyone (nobody helped, but thanks anyway). Hope someone finds this useful in 2022.

User contributions licensed under: CC BY-SA