
Aggregate data from multiple rows to one and then nest the data

I’m relatively new to Scala and Spark programming.

I have a use case where I need to group the data by certain columns, get a count of another column per group (using pivot), and then finally create a nested DataFrame out of my flat DataFrame.

One major challenge I am facing is that I also need to retain certain other columns (not the one I am pivoting on).

I’m not able to figure out an efficient way to do it.

INPUT

ID ID2 ID3 country items_purchased quantity
1  1   1   UK      apple           1
1  1   1   USA     mango           1
1  2   3   China   banana          3
2  1   1   UK      mango           1

Now say I want to pivot on 'country' and group by ('ID', 'ID2', 'ID3'), but I also want to maintain the other columns as lists.

For instance,

OUTPUT-1 :

ID ID2 ID3 UK USA China items_purchased quantity
1  1   1   1  1    0    [apple,mango]   [1,1] 
1  2   3   0  0    1    [banana]        [3]
2  1   1   1  0    0    [mango]         [1]

Once I achieve this,

I want to nest it into a structure such that my schema looks like:

{
  "ID": 1,
  "ID2": 1,
  "ID3": 1,
  "countries": {
    "UK": 1,
    "USA": 1,
    "China": 0
  },
  "items_purchased": ["apple", "mango"],
  "quantity": [1, 1]
}

I believe I can use a case class and then map every row of the DataFrame to it. However, I am not sure whether that is an efficient way. I would love to know if there is a more optimised way to achieve this.

What I have in mind is something along these lines:

dataframe.map(row => myCaseClass(
  row.getAs[Long]("ID"),
  row.getAs[Long]("ID2"),
  row.getAs[Long]("ID3"),
  CountriesCaseClass(
    row.getAs[String]("UK")
  )
))

and so on…
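
For reference, a fully spelled-out version of that idea could look something like the sketch below. It is only an illustration of the case-class mapping described above: the names PurchaseRecord, Countries and flatDF (standing for the flat, pivoted DataFrame of OUTPUT-1), as well as the exact column types, are assumptions of mine and would need to be adapted.

import spark.implicits._ // assumes `spark` is the active SparkSession; provides the Dataset encoders

// hypothetical case classes mirroring the desired nested schema
case class Countries(UK: Long, USA: Long, China: Long)
case class PurchaseRecord(
  ID: Long,
  ID2: Long,
  ID3: Long,
  countries: Countries,
  items_purchased: Seq[String],
  quantity: Seq[Long])

// `flatDF` stands for the flat, pivoted DataFrame shown as OUTPUT-1
val nestedDS = flatDF.map { row =>
  PurchaseRecord(
    row.getAs[Long]("ID"),
    row.getAs[Long]("ID2"),
    row.getAs[Long]("ID3"),
    Countries(
      row.getAs[Long]("UK"),  // pivot counts are Longs; adjust types to your actual columns
      row.getAs[Long]("USA"),
      row.getAs[Long]("China")),
    row.getAs[Seq[String]]("items_purchased"),
    row.getAs[Seq[Long]]("quantity"))
}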


Answer

I think this should work for your case. The number of partitions is calculated from the formula partitions_num = data_size / 500MB.

import org.apache.spark.sql.functions.{collect_list, col, lit, map}
import spark.implicits._ // assuming `spark` is your SparkSession; needed for .toDF and the $"column" syntax

val data = Seq(
(1, 1, 1, "UK", "apple", 1),
(1, 1, 1, "USA","mango", 1),
(1, 2, 3, "China", "banana", 3),
(2, 1, 1, "UK", "mango", 1))

// e.g. partitions_num = 100GB / 500MB = 200; adjust it according to the size of your data
val partitions_num = 250
val df = data.toDF("ID", "ID2", "ID3", "country", "items_purchased", "quantity")
              .repartition(partitions_num, $"ID", $"ID2", $"ID3") // the partitioning key should remain the same for all the operations
              .persist()

// get the distinct countries; we will need them to fill the nulls with 0 after pivoting, for the map column and for the drop
val countries = df.select("country").distinct.collect.map{_.getString(0)}

// create a sequence of alternating key/value columns, which is the input expected by the map function
val countryMapping = countries.flatMap{c => Seq(lit(c), col(c))}
val pivotCountriesDF = df.select("ID", "ID2", "ID3", "country")
                          .groupBy("ID", "ID2", "ID3")
                          .pivot($"country")
                          .count()
                          .na.fill(0, countries)
                          .withColumn("countries", map(countryMapping:_*)) // i.e. map(lit("UK"), col("UK"), lit("China"), col("China")) -> {"UK": 0, "China": 1}
                          .drop(countries:_*)

// pivotCountriesDF.rdd.getNumPartitions == 250; Spark retains the partition number since we didn't change the partitioning key

// +---+---+---+-------------------------------+
// |ID |ID2|ID3|countries                      |
// +---+---+---+-------------------------------+
// |1  |2  |3  |[China -> 1, USA -> 0, UK -> 0]|
// |1  |1  |1  |[China -> 0, USA -> 1, UK -> 1]|
// |2  |1  |1  |[China -> 0, USA -> 0, UK -> 1]|
// +---+---+---+-------------------------------+
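
As a side note: the map column keeps whatever country keys appear in the data. If you would rather have countries as a struct with fixed, named fields (closer to the schema sketched in the question), one possible variation, reusing the same df and countries variables from above, is to build it with struct instead of map:

import org.apache.spark.sql.functions.struct

// variation: build a struct whose field names are the pivoted country columns,
// then drop the flat columns; the field names come from the collected `countries`
val structCountriesDF = df.select("ID", "ID2", "ID3", "country")
                          .groupBy("ID", "ID2", "ID3")
                          .pivot($"country")
                          .count()
                          .na.fill(0, countries)
                          .withColumn("countries", struct(countries.map(col): _*))
                          .drop(countries: _*)

With struct every row carries all country fields in a fixed schema, whereas the map keeps the keys dynamic.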

val listDF = df.select("ID", "ID2", "ID3", "items_purchased", "quantity")
                .groupBy("ID", "ID2", "ID3")
                .agg(
                  collect_list("items_purchased").as("items_purchased"), 
                  collect_list("quantity").as("quantity"))

// +---+---+---+---------------+--------+
// |ID |ID2|ID3|items_purchased|quantity|
// +---+---+---+---------------+--------+
// |1  |2  |3  |[banana]       |[3]     |
// |1  |1  |1  |[apple, mango] |[1, 1]  |
// |2  |1  |1  |[mango]        |[1]     |
// +---+---+---+---------------+--------+


// listDF.rdd.getNumPartitions == 250; to validate this, try changing the partitioning key with .groupBy("ID", "ID2"): it will fall back to the default value (200) of the spark.sql.shuffle.partitions setting

val joinedDF = pivotCountriesDF.join(listDF, Seq("ID", "ID2", "ID3"))

// joinedDF.rdd.getNumPartitions == 250, the same partitions will be used for the join as well.

// +---+---+---+-------------------------------+---------------+--------+
// |ID |ID2|ID3|countries                      |items_purchased|quantity|
// +---+---+---+-------------------------------+---------------+--------+
// |1  |2  |3  |[China -> 1, USA -> 0, UK -> 0]|[banana]       |[3]     |
// |1  |1  |1  |[China -> 0, USA -> 1, UK -> 1]|[apple, mango] |[1, 1]  |
// |2  |1  |1  |[China -> 0, USA -> 0, UK -> 1]|[mango]        |[1]     |
// +---+---+---+-------------------------------+---------------+--------+

joinedDF.toJSON.show(false)

// +--------------------------------------------------------------------------------------------------------------------+
// |value                                                                                                               |
// +--------------------------------------------------------------------------------------------------------------------+
// |{"ID":1,"ID2":2,"ID3":3,"countries":{"China":1,"USA":0,"UK":0},"items_purchased":["banana"],"quantity":[3]}         |
// |{"ID":1,"ID2":1,"ID3":1,"countries":{"China":0,"USA":1,"UK":1},"items_purchased":["apple","mango"],"quantity":[1,1]}|
// |{"ID":2,"ID2":1,"ID3":1,"countries":{"China":0,"USA":0,"UK":1},"items_purchased":["mango"],"quantity":[1]}          |
// +--------------------------------------------------------------------------------------------------------------------+
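
If you need the nested records persisted rather than just shown, you can also write joinedDF out as JSON directly; the map and array columns serialise into the same nested shape. The output path below is just a placeholder.

// each row is written as one JSON object with the nested countries map and the two arrays
joinedDF.write.mode("overwrite").json("/path/to/output")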

Good luck and let me know if you need any clarification.
