Filtering rows in a PySpark DataFrame and creating a new column that contains the result

I am trying to identify the crimes that happen within the SF downtown boundary on Sundays. My idea was to first write a UDF that labels each crime according to whether it falls inside the area I define as downtown: "1" if it does and "0" if not. After that, I want to create a new column to store those results. I tried my best, but it doesn't work for some reason. Here is the code I wrote:

from pyspark.sql.types import BooleanType
from pyspark.sql.functions import udf

def filter_dt(x,y):
  if (((x < -122.4213) & (x > -122.4313)) & ((y > 37.7540) & (y < 37.7740))):
    return '1'
  else:
    return '0'

schema = StructType([StructField("isDT", BooleanType(), False)])
filter_dt_boolean = udf(lambda row: filter_dt(row[0], row[1]), schema)

#First, pick out the crime cases that happens on Sunday BooleanType()
q3_sunday = spark.sql("SELECT * FROM sf_crime WHERE DayOfWeek='Sunday'")
#Then, we add a new column for us to filter out(identify) if the crime is in DT
q3_final = q3_result.withColumn("isDT", filter_dt(q3_sunday.select('X'),q3_sunday.select('Y')))

The error I am getting is: [image: error message]

My guess is that my current UDF doesn't accept a whole column as input for the comparison, but I don't know how to fix it. Please help! Thank you!

Answer

Sample data would have helped. For now, I assume that your data looks like this:

+----+---+---+
|val1|  x|  y|
+----+---+---+
|  10|  7| 14|
|   5|  1|  4|
|   9|  8| 10|
|   2|  6| 90|
|   7|  2| 30|
|   3|  5| 11|
+----+---+---+

Then you don't need a UDF, as you can do the evaluation with the when() function:

import pyspark.sql.functions as F

tst = spark.createDataFrame([(10,7,14),(5,1,4),(9,8,10),(2,6,90),(7,2,30),(3,5,11)], schema=['val1','x','y'])
# between() is inclusive on both ends: 4 <= x <= 10 and 11 <= y <= 20
tst_res = tst.withColumn("isdt", F.when((tst.x.between(4,10)) & (tst.y.between(11,20)), 1).otherwise(0))

This will give the result:

tst_res.show()
+----+---+---+----+
|val1|  x|  y|isdt|
+----+---+---+----+
|  10|  7| 14|   1|
|   5|  1|  4|   0|
|   9|  8| 10|   0|
|   2|  6| 90|   0|
|   7|  2| 30|   0|
|   3|  5| 11|   1|
+----+---+---+----+

If I have misread the data and you still need to pass multiple values to a UDF, you can pass them as an array or a struct (a UDF also accepts several columns as separate arguments). I prefer a struct:

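import pyspark.sql.functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

@udf(IntegerType())
def check_data(row):
    # Same inclusive bounds as the when() version: 4 <= x <= 10 and 11 <= y <= 20
    if 4 <= row.x <= 10 and 11 <= row.y <= 20:
        return 1
    else:
        return 0

# Bundle x and y into a struct so the UDF receives them as a single Row
tst_res1 = tst.withColumn("isdt", check_data(F.struct('x','y')))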

The result will be the same. However, it is generally better to avoid UDFs and use Spark's built-in functions, since the Catalyst optimizer cannot see the logic inside a UDF and therefore cannot optimize it.
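
If a UDF is still wanted for the original downtown check, one possible shape is sketched below. It keeps the bounding-box test in a plain Python function and wraps it as a two-argument UDF that receives Column objects, rather than DataFrames such as q3_sunday.select('X'). The names in_downtown and add_downtown_flag are hypothetical, and the column names X and Y are assumed from the question.

```python
# Bounding box taken from the question (assumed: X = longitude, Y = latitude).
X_MIN, X_MAX = -122.4313, -122.4213
Y_MIN, Y_MAX = 37.7540, 37.7740


def in_downtown(x, y):
    """Return 1 if (x, y) lies strictly inside the box, else 0."""
    # Spark passes None for null values, so treat missing coordinates as outside.
    if x is None or y is None:
        return 0
    return int(X_MIN < x < X_MAX and Y_MIN < y < Y_MAX)


def add_downtown_flag(df, x_col="X", y_col="Y"):
    """Append an integer isDT column computed by a two-argument UDF."""
    # Imported lazily so in_downtown can be tested without Spark installed.
    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import IntegerType

    in_downtown_udf = udf(in_downtown, IntegerType())
    # Pass Column objects (cast to double in case they were read as strings).
    return df.withColumn(
        "isDT",
        in_downtown_udf(col(x_col).cast("double"), col(y_col).cast("double")),
    )
```

Keeping the logic in a plain function means it can be checked with ordinary Python calls, for example in_downtown(-122.4250, 37.7600), before it is run on the cluster. The declared return type (IntegerType) also matches what the function actually returns, unlike returning the strings '1' and '0' from a UDF declared with a boolean type.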
