I’m relatively new to Scala and Spark programming.
I have a use case where I need to group data by certain columns, get a count of a certain column (using pivot), and then finally create a nested DataFrame out of my flat DataFrame.
One major challenge I am facing is that I need to retain certain other columns as well (not the one I am pivoting on).
I’m not able to figure out an efficient way to do it.
INPUT:

```
ID  ID2  ID3  country  items_purchased  quantity
1   1    1    UK       apple            1
1   1    1    USA      mango            1
1   2    3    China    banana           3
2   1    1    UK       mango            1
```
Now say I want to pivot on 'country' and group by ('ID', 'ID2', 'ID3'), but I also want to maintain the other columns as a list.
For instance,
OUTPUT-1:

```
ID  ID2  ID3  UK  USA  China  items_purchased  quantity
1   1    1    1   1    0      [apple,mango]    [1,1]
1   2    3    0   0    1      [banana]         [3]
2   1    1    1   0    0      [mango]          [1]
```
Once I achieve this, I want to nest it into a structure whose schema looks like:
{ "ID" : 1, "ID2" : 1, "ID3" : 1, "countries" : { "UK" : 1, "USA" : 1, "China" : 0, }, "items_purchased" : ["apple", "mango"], "quantity" : [1,1] }
I believe I can use a case class and then map every row of the DataFrame to it. However, I am not sure if that is an efficient way. I would love to know if there is a more optimised way to achieve this.
What I have in mind is something along these lines:
```scala
dataframe.map(row => myCaseClass(
  row.getAs[Long]("ID"),
  row.getAs[Long]("ID2"),
  row.getAs[Long]("ID3"),
  CountriesCaseClass(
    row.getAs[String]("UK")
  )
))
```
and so on…
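For reference, a minimal sketch of that case-class approach, applied to the flat OUTPUT-1 shape, might look like the following. The class names, the `Long` column types, and the `flatDf` input are all assumptions for illustration; `getAs` will throw at runtime if the types don't match the actual schema:

```scala
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

// Hypothetical target classes; field names mirror the OUTPUT-1 columns
case class Countries(UK: Long, USA: Long, China: Long)
case class Purchase(ID: Long, ID2: Long, ID3: Long,
                    countries: Countries,
                    items_purchased: Seq[String],
                    quantity: Seq[Long])

val spark = SparkSession.builder.getOrCreate()
import spark.implicits._ // provides the Encoder needed by Dataset.map

// `flatDf` is assumed to be the pivoted/aggregated DataFrame from OUTPUT-1
def nest(flatDf: DataFrame): Dataset[Purchase] =
  flatDf.map { row =>
    Purchase(
      row.getAs[Long]("ID"),
      row.getAs[Long]("ID2"),
      row.getAs[Long]("ID3"),
      Countries(
        row.getAs[Long]("UK"),
        row.getAs[Long]("USA"),
        row.getAs[Long]("China")),
      row.getAs[Seq[String]]("items_purchased"),
      row.getAs[Seq[Long]]("quantity"))
  }
```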
Answer
I think this should work for your case. The number of partitions is calculated from the formula `partitions_num = data_size / 500MB`.
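As a small illustration of that formula (the 500MB-per-partition target and the helper name are assumptions; how you estimate the input size is up to you):

```scala
// Hypothetical helper: derive a partition count from an estimated input size
def partitionsFor(dataSizeBytes: Long,
                  targetPartitionBytes: Long = 500L * 1024 * 1024): Int =
  math.max(1, math.ceil(dataSizeBytes.toDouble / targetPartitionBytes).toInt)

partitionsFor(100L * 1024 * 1024 * 1024) // ~100GB of input -> 205 partitions
```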
```scala
import org.apache.spark.sql.functions.{collect_list, count, col, lit, map}
import spark.implicits._ // for toDF and the $"..." column syntax (assumes a SparkSession in scope as `spark`)

val data = Seq(
  (1, 1, 1, "UK",    "apple",  1),
  (1, 1, 1, "USA",   "mango",  1),
  (1, 2, 3, "China", "banana", 3),
  (2, 1, 1, "UK",    "mango",  1))

// e.g: partitions_num = 100GB / 500MB = 200, adjust it according to the size of your data
val partitions_num = 250

val df = data.toDF("ID", "ID2", "ID3", "country", "items_purchased", "quantity")
  .repartition(partitions_num, $"ID", $"ID2", $"ID3") // the partitioning should remain the same for all the operations
  .persist()

// get the countries; we need them to fill the null values with 0 after pivoting, for the mapping and for the drop
val countries = df.select("country").distinct.collect.map{_.getString(0)}

// creates a sequence of key/value pairs which will be the input for the map function
val countryMapping = countries.flatMap{c => Seq(lit(c), col(c))}

val pivotCountriesDF = df.select("ID", "ID2", "ID3", "country")
  .groupBy("ID", "ID2", "ID3")
  .pivot($"country")
  .count()
  .na.fill(0, countries)
  .withColumn("countries", map(countryMapping:_*)) // i.e map("UK", col("UK"), "China", col("China")) -> {"UK":0, "China":1}
  .drop(countries:_*)

// pivotCountriesDF.rdd.getNumPartitions == 250; Spark retains the partition number since we didn't change the partition key

// +---+---+---+-------------------------------+
// |ID |ID2|ID3|countries                      |
// +---+---+---+-------------------------------+
// |1  |2  |3  |[China -> 1, USA -> 0, UK -> 0]|
// |1  |1  |1  |[China -> 0, USA -> 1, UK -> 1]|
// |2  |1  |1  |[China -> 0, USA -> 0, UK -> 1]|
// +---+---+---+-------------------------------+

val listDF = df.select("ID", "ID2", "ID3", "items_purchased", "quantity")
  .groupBy("ID", "ID2", "ID3")
  .agg(
    collect_list("items_purchased").as("items_purchased"),
    collect_list("quantity").as("quantity"))

// +---+---+---+---------------+--------+
// |ID |ID2|ID3|items_purchased|quantity|
// +---+---+---+---------------+--------+
// |1  |2  |3  |[banana]       |[3]     |
// |1  |1  |1  |[apple, mango] |[1, 1]  |
// |2  |1  |1  |[mango]        |[1]     |
// +---+---+---+---------------+--------+

// listDF.rdd.getNumPartitions == 250; to validate this, change the partition key to .groupBy("ID", "ID2")
// and it will fall back to the default value 200 of the spark.sql.shuffle.partitions setting

val joinedDF = pivotCountriesDF.join(listDF, Seq("ID", "ID2", "ID3"))

// joinedDF.rdd.getNumPartitions == 250; the same partitions are used for the join as well

// +---+---+---+-------------------------------+---------------+--------+
// |ID |ID2|ID3|countries                      |items_purchased|quantity|
// +---+---+---+-------------------------------+---------------+--------+
// |1  |2  |3  |[China -> 1, USA -> 0, UK -> 0]|[banana]       |[3]     |
// |1  |1  |1  |[China -> 0, USA -> 1, UK -> 1]|[apple, mango] |[1, 1]  |
// |2  |1  |1  |[China -> 0, USA -> 0, UK -> 1]|[mango]        |[1]     |
// +---+---+---+-------------------------------+---------------+--------+

joinedDF.toJSON.show(false)

// +--------------------------------------------------------------------------------------------------------------------+
// |value                                                                                                               |
// +--------------------------------------------------------------------------------------------------------------------+
// |{"ID":1,"ID2":2,"ID3":3,"countries":{"China":1,"USA":0,"UK":0},"items_purchased":["banana"],"quantity":[3]}         |
// |{"ID":1,"ID2":1,"ID3":1,"countries":{"China":0,"USA":1,"UK":1},"items_purchased":["apple","mango"],"quantity":[1,1]}|
// |{"ID":2,"ID2":1,"ID3":1,"countries":{"China":0,"USA":0,"UK":1},"items_purchased":["mango"],"quantity":[1]}          |
// +--------------------------------------------------------------------------------------------------------------------+
```
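If you need to persist the result rather than just show it, a follow-up along these lines should work (the output path is a placeholder):

```scala
// one JSON object per line, matching the toJSON output above
joinedDF.write.mode("overwrite").json("/tmp/nested_output")

// release the cached input once both branches have been computed
df.unpersist()
```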
Good luck and let me know if you need any clarification.