PySpark: create a new column based on another column with multiple conditions using a list or set

Question:

I am trying to create a new column in a PySpark DataFrame. I have data like the following:

+------+
|letter|
+------+
|     A|
|     C|
|     A|
|     Z|
|     E|
+------+

I want to add a new column based on the letter column, according to this mapping:

+------+-----+
|letter|group|
+------+-----+
|     A|   c1|
|     B|   c1|
|     F|   c2|
|     G|   c2|
|     I|   c3|
+------+-----+

There can be multiple categories, each with many individual letter values (around 100 in total, some of them consisting of multiple letters).

I have done this with a UDF, and it works well:

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

c1 = ['A','B','C','D']
c2 = ['E','F','G','H']
c3 = ['I','J','K','L']
...

def l2c(value):
    if value in c1: return 'c1'
    elif value in c2: return 'c2'
    elif value in c3: return 'c3'
    else: return "na"

udf_l2c = udf(l2c, StringType())
data_with_category = data.withColumn("group", udf_l2c("letter"))
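With around 100 letters spread over many lists, every "value in c1" check scans a whole list, and each new category needs another elif. A minimal sketch, assuming the categories are kept in a single dict (the name categories and its contents are illustrative, not from the question), inverts them once into a letter-to-group lookup so l2c becomes one dict.get call:

```python
# Hypothetical category definitions; in practice these would be the
# question's c1, c2, c3, ... lists.
categories = {
    "c1": ["A", "B", "C", "D"],
    "c2": ["E", "F", "G", "H"],
    "c3": ["I", "J", "K", "L"],
}

# Invert once: letter -> group. If a letter appears in several lists,
# setdefault keeps the first category, matching the if/elif order in l2c.
letter_to_group = {}
for group, letters in categories.items():
    for letter in letters:
        letter_to_group.setdefault(letter, group)


def l2c(value):
    # One hash lookup instead of one list scan per category.
    return letter_to_group.get(value, "na")


print([l2c(v) for v in ["A", "C", "A", "Z", "E"]])
# ['c1', 'c1', 'c1', 'na', 'c2']
```

The function can be wrapped with udf(l2c, StringType()) exactly as in the question.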

Now I am trying to do it without a UDF, maybe using when and col. What I have tried is the following. It works, but the code is very long:

from pyspark.sql.functions import when, col

data_with_category = data.withColumn('group', when(col('letter') == 'A', 'c1')
    .when(col('letter') == 'B', 'c1')
    .when(col('letter') == 'F', 'c2')
    ...

It is very long, and writing a new when condition for every possible value of letter is not a good approach. The number of letters can be very large (around 100) in my case, so I tried:

data_with_category = data.withColumn('group', when(col('letter') in ['A','B','C','D'] ,'c1')
    .when(col('letter') in ['E','F','G','H'], 'c2')
    .when(col('letter') in ['I','J','K','L'], 'c3')

But it raises an error. How can I solve this?

Asked By: Prabhu


Answers:

You can try using a UDF, for example:

from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

def say_hello(name):
    return f"Hello {name}"

say_hello_udf = udf(say_hello, StringType())
df = spark.createDataFrame([("Rick",), ("Morty",)], ["name"])
df.withColumn("greetings", say_hello_udf(col("name"))).show()

or, with the decorator form:

@udf(returnType=StringType())
def say_hello(name):
    return f"Hello {name}"

df.withColumn("greetings", say_hello(col("name"))).show()
Answered By: Phạm Ngọc Quý

Use isin, which tests whether the column's value is contained in a given list.
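The original attempt fails because of how Python's in operator works: "x in some_list" compares x with each element using == and then calls bool() on the result. For a PySpark Column, == builds a new Column expression instead of returning True or False, and Column.__bool__ raises a ValueError asking you to use &, |, ~ instead. The sketch below imitates that behaviour with a hypothetical FakeColumn stand-in class (not part of PySpark), so the mechanism can be seen without a Spark session:

```python
class FakeColumn:
    """Hypothetical stand-in mimicking how a PySpark Column reacts to == and bool()."""

    def __init__(self, expr):
        self.expr = expr

    def __eq__(self, other):
        # Like Column.__eq__: return a new expression, not True/False.
        return FakeColumn(f"({self.expr} = {other!r})")

    def __bool__(self):
        # Like Column.__bool__: refuse to collapse an expression into a bool.
        raise ValueError("Cannot convert column into bool")


letter = FakeColumn("letter")

try:
    letter in ["A", "B", "C", "D"]  # calls == on each element, then bool()
except ValueError as exc:
    print("in failed:", exc)
```

Column.isin avoids this by building the whole membership test as a single Spark expression.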

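import pyspark.sql.functions as F

c1 = ['A','B','C','D']
c2 = ['E','F','G','H']
c3 = ['I','J','K','L']

# Letters that match no list get null; append .otherwise(F.lit('na'))
# to mirror the "na" default of the UDF in the question.
df.withColumn("group", F.when(F.col("letter").isin(c1), F.lit('c1'))
                        .when(F.col("letter").isin(c2), F.lit('c2'))
                        .when(F.col("letter").isin(c3), F.lit('c3'))).show()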

#+------+-----+
#|letter|group|
#+------+-----+
#|     A|   c1|
#|     B|   c1|
#|     F|   c2|
#|     G|   c2|
#|     I|   c3|
#+------+-----+
Answered By: murtihash