How to group by conditional aggregation of adjacent rows in PySpark

Question:

I am facing an issue when doing conditional grouping on a Spark DataFrame.

Below is a complete example.

I have a DataFrame that has been sorted by user and by time:

 activity          location  user
0 watch movie        house    A
1 sleep              house    A
2 cardio             gym      A
3 cardio             gym      B
4 buy biscuits       shop     B
5 cardio             gym      B
6 weight training    gym      B

I only want to aggregate rows when the 'location' field is the same across adjacent rows for a given user, so a plain df.groupby(['user', 'location']).activity.collect(",") is not enough. The desired output looks like the following. The order of the rows also matters.

activity                   location  user
watch movie,sleep          house     A
cardio                     gym       A
cardio                     gym       B
buy biscuits               shop      B
cardio, weight training    gym       B

This is similar to the question "Groupby conditional sum of adjacent rows pandas", but I need it for a PySpark DataFrame, because converting the PySpark DataFrame to pandas runs out of memory (the dataset is huge).

Asked By: Neha Zaveri

Answers:

You need two steps to do that, assuming df is your DataFrame and it has a time column that defines the row order.

1. Create a group ID

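from pyspark.sql import functions as F, Window as W

# First pass: store the previous row's location (per user, ordered by time).
df = df.withColumn(
    "grp_id", F.lag("location").over(W.partitionBy("user").orderBy("time"))
).withColumn(
    # Second pass: count 1 whenever the location changes (or there is no
    # previous row), then take a running sum so each run shares one ID.
    "grp_id",
    F.sum(F.when(F.col("grp_id") == F.col("location"), 0).otherwise(1)).over(
        W.partitionBy("user").orderBy("time").rowsBetween(W.unboundedPreceding, W.currentRow)
    ),
)
df.show()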

+----+---------------+--------+----+------+
|time|       activity|location|user|grp_id|
+----+---------------+--------+----+------+
|   0|watch movie    |   house|   A|     1|
|   1|sleep          |   house|   A|     1|
|   2|cardio         |   gym  |   A|     2|
|   3|cardio         |   gym  |   B|     1|
|   4|buy biscuits   |   shop |   B|     2|
|   5|cardio         |   gym  |   B|     3|
|   6|weight training|   gym  |   B|     3|
+----+---------------+--------+----+------+
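
To check the group ID logic on a small sample without a Spark session, the same "compare with the previous row, then take a running sum" idea can be sketched in plain Python. This is only an illustration; the helper name group_ids is hypothetical and works on an ordinary list of locations for a single user.

```python
from itertools import accumulate


def group_ids(locations):
    """Number runs of equal adjacent values, starting at 1."""
    # Shift the list by one position; this plays the role of lag("location").
    previous = [None] + locations[:-1]
    # 1 where the location differs from the previous row (or there is none), else 0.
    changed = [0 if prev == cur else 1 for prev, cur in zip(previous, locations)]
    # Running sum of the change flags, like sum(...).over(window) in Spark.
    return list(accumulate(changed))


user_b = ["gym", "shop", "gym", "gym"]
print(group_ids(user_b))  # [1, 2, 3, 3]
```

The two gym runs for user B get different IDs (1 and 3), which is what keeps them apart in the later groupBy.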

2. Do the aggregation

# Group on grp_id as well, so separate runs at the same location stay apart.
df = df.groupBy("user", "grp_id", "location").agg(F.collect_list("activity"))

df.show()

+----+------+--------+----------------------------------+
|user|grp_id|location|collect_list(activity)            |
+----+------+--------+----------------------------------+
|A   |1     |house   |[watch movie    , sleep          ]|
|A   |2     |gym     |[cardio         ]                 |
|B   |1     |gym     |[cardio         ]                 |
|B   |2     |shop    |[buy biscuits   ]                 |
|B   |3     |gym     |[cardio         , weight training]|
+----+------+--------+----------------------------------+
Answered By: Steven