
In SQL, how do I group by each of a long list of columns and get the counts, all in one table?

I performed stratified sampling on a multi-label dataset before training a classifier, and now I want to check how balanced the sample is. The columns in the dataset are:


I want to group by every label_* column once and build a dictionary of the negative/positive counts for each label. At the moment I am doing this in PySpark SQL like this:

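# Evaluate how skewed the sample is after balancing it by resampling
# (`spark` is an existing SparkSession; `limit` is defined earlier)
stratified_sample_path = 's3://stackoverflow-events/1901/Sample.Stratified.{}.*.jsonl'.format(limit)
# Register the JSON Lines files as a table so spark.sql() can query it by name
spark.read.json(stratified_sample_path).createOrReplaceTempView('stratified_sample')

label_counts = {}
for i in range(0, 100):
    # ORDER BY makes the 0 (negative) row reliably come before the 1 (positive) row
    count_df = spark.sql(
        'SELECT label_{0}, COUNT(*) AS total FROM stratified_sample '
        'GROUP BY label_{0} ORDER BY label_{0}'.format(i))
    rows = count_df.take(2)
    neg_count = rows[0]['total']
    pos_count = rows[1]['total']
    label_counts[i] = [neg_count, pos_count]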

The output looks like this:

{0: [1034673, 14491],
 1: [1023250, 25914],
 2: [1030462, 18702],
 3: [1035645, 13519],
 4: [1037445, 11719],
 5: [1010664, 38500],
 6: [1031699, 17465],

This feels like it should be possible in one SQL statement, but I can’t figure out how to do this or find an existing solution. Obviously I don’t want to write out all the column names and generating SQL seems worse than this solution.

Can SQL do this? Thanks!



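You can do this in one query by generating the SQL, and you don't need GROUP BY at all: a single aggregate over the whole table returns every count in one row.

Something like:

SELECT COUNT(*) AS total, SUM(label_0) AS positive_0, SUM(label_1) AS positive_1, ... FROM stratified_sample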
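Writing out all the SUM terms by hand isn't necessary. Below is a minimal sketch that builds the statement in Python, assuming the columns are named label_0 through label_{n-1} and hold 0/1 values; `build_count_query` and its `table` default are hypothetical names, not part of any library.

```python
def build_count_query(n_labels, table='stratified_sample'):
    """Build one SELECT that counts all rows and sums every 0/1 label column."""
    # One "SUM(label_k) AS positive_k" term per label, joined with commas
    sums = ', '.join('SUM(label_{0}) AS positive_{0}'.format(k) for k in range(n_labels))
    return 'SELECT COUNT(*) AS total, {} FROM {}'.format(sums, table)


print(build_count_query(2))
# SELECT COUNT(*) AS total, SUM(label_0) AS positive_0, SUM(label_1) AS positive_1 FROM stratified_sample
```

In PySpark the resulting string would then be passed to `spark.sql(build_count_query(100))`, which scans the table once instead of running 100 separate GROUP BY queries.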

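Then use that one result row to build your dict as {k: [total - positive_k, positive_k]}. This relies on each label_k column holding 0 or 1, so that SUM(label_k) is the number of positives.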
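The conversion from the single result row to the dictionary can be sketched as follows. To keep it runnable without a cluster, this uses an in-memory SQLite table with two made-up 0/1 label columns as a stand-in for the Spark table; `label_counts_from_row` is a hypothetical helper and the rows are invented for illustration.

```python
import sqlite3


def label_counts_from_row(row, n_labels):
    """Turn a (total, positive_0, ..., positive_{n-1}) row into {k: [neg, pos]}."""
    total = row[0]
    # Negatives are whatever is left of the total after the positives
    return {k: [total - row[k + 1], row[k + 1]] for k in range(n_labels)}


# Stand-in for the Spark table: two hypothetical 0/1 label columns
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE stratified_sample (label_0 INTEGER, label_1 INTEGER)')
conn.executemany('INSERT INTO stratified_sample VALUES (?, ?)',
                 [(0, 1), (1, 1), (0, 0), (0, 1)])

# One aggregate query, no GROUP BY, returns a single row
row = conn.execute('SELECT COUNT(*) AS total, SUM(label_0) AS positive_0, '
                   'SUM(label_1) AS positive_1 FROM stratified_sample').fetchone()
label_counts = label_counts_from_row(row, 2)
print(label_counts)  # {0: [3, 1], 1: [1, 3]}
```

With Spark, `spark.sql(query).first()` returns a `Row`, which is a tuple subclass and indexes the same way, so the helper should carry over unchanged; this sketch has only been checked against SQLite.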