How can I pivot on multiple columns separately in PySpark

Question:

Is it possible to pivot on several columns at once in PySpark?
I have a DataFrame like this:

from pyspark.sql import SparkSession
from pyspark.sql import functions as sf
import pandas as pd

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame(
    pd.DataFrame([[1, 'str1', 'str4'], [1, 'str1', 'str4'], [1, 'str2', 'str4'], [1, 'str2', 'str5'],
        [1, 'str3', 'str5'], [2, 'str2', 'str4'], [2, 'str2', 'str4'], [2, 'str3', 'str4'],
        [2, 'str3', 'str5']], columns=['id', 'col1', 'col2'])
)
# +----+------+------+
# | id | col1 | col2 |
# +----+------+------+
# |  1 | str1 | str4 |
# |  1 | str1 | str4 |
# |  1 | str2 | str4 |
# |  1 | str2 | str5 |
# |  1 | str3 | str5 |
# |  2 | str2 | str4 |
# |  2 | str2 | str4 |
# |  2 | str3 | str4 |
# |  2 | str3 | str5 |
# +----+------+------+

I want to pivot it on multiple columns ("col1", "col2", …) to have a dataframe that looks like this:

+----+-----------+-----------+-----------+-----------+-----------+
| id | col1_str1 | col1_str2 | col1_str3 | col2_str4 | col2_str5 |
+----+-----------+-----------+-----------+-----------+-----------+
|  1 |         2 |         2 |         1 |         3 |         2 |
|  2 |         0 |         2 |         2 |         3 |         1 |
+----+-----------+-----------+-----------+-----------+-----------+
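As a quick sanity check on the expected table, the plain-Python sketch below (no Spark involved; `ROWS` and `COLS` are just the sample data retyped) counts each `<column>_<value>` pair per id. It only reproduces the target numbers and is not meant to replace the Spark code.

```python
from collections import Counter

# Sample rows retyped from the question: (id, col1, col2)
ROWS = [
    (1, 'str1', 'str4'), (1, 'str1', 'str4'), (1, 'str2', 'str4'),
    (1, 'str2', 'str5'), (1, 'str3', 'str5'), (2, 'str2', 'str4'),
    (2, 'str2', 'str4'), (2, 'str3', 'str4'), (2, 'str3', 'str5'),
]
COLS = ['col1', 'col2']

# Count each (id, "<column>_<value>") pair; every cell of the target table
# is one of these counts.
counts = Counter(
    (row[0], f'{name}_{value}')
    for row in ROWS
    for name, value in zip(COLS, row[1:])
)

print(counts[(1, 'col2_str4')], counts[(1, 'col2_str5')])  # 3 2
```

Id 1 has two rows with `col2 == 'str5'`, so its `col2_str5` cell is 2.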

I’ve found a solution that works:

sdf_pivot_col1 = (
    sdf
    .groupby('id')
    .pivot('col1')
    .agg(sf.count('id'))
)
sdf_pivot_col2 = (
    sdf
    .groupby('id')
    .pivot('col2')
    .agg(sf.count('id'))
)

sdf_result = (
    sdf
    .select('id').distinct()
    .join(sdf_pivot_col1, on='id', how='left')
    .join(sdf_pivot_col2, on='id', how='left')
)
# .show() returns None, so call it after the assignment
sdf_result.show()

# +---+----+----+----+----+----+
# | id|str1|str2|str3|str4|str5|
# +---+----+----+----+----+----+
# |  1|   2|   2|   1|   3|   2|
# |  2|null|   2|   2|   3|   1|
# +---+----+----+----+----+----+
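A plain-Python sketch of what this join-based approach does, assuming the same sample rows: pivot each column on its own, then left-join the results onto the distinct ids. The helper name `pivot_one` is hypothetical. A value an id never had comes out as `None`, which corresponds to the `null` in the output above.

```python
from collections import Counter

# Sample rows from the question: (id, col1, col2)
ROWS = [
    (1, 'str1', 'str4'), (1, 'str1', 'str4'), (1, 'str2', 'str4'),
    (1, 'str2', 'str5'), (1, 'str3', 'str5'), (2, 'str2', 'str4'),
    (2, 'str2', 'str4'), (2, 'str3', 'str4'), (2, 'str3', 'str5'),
]


def pivot_one(rows, idx):
    """Hypothetical helper: per-id counts of column `idx`, one key per distinct value."""
    values = sorted({r[idx] for r in rows})
    counts = Counter((r[0], r[idx]) for r in rows)
    # Counter.get returns None for a missing pair, like the null Spark leaves
    return {i: {v: counts.get((i, v)) for v in values} for i in {r[0] for r in rows}}


ids = sorted({r[0] for r in ROWS})  # like select('id').distinct()
p1, p2 = pivot_one(ROWS, 1), pivot_one(ROWS, 2)
# Left join on id: merge both pivots' cells for every distinct id
result = {i: {**p1.get(i, {}), **p2.get(i, {})} for i in ids}
print(result[2])  # {'str1': None, 'str2': 2, 'str3': 2, 'str4': 3, 'str5': 1}
```

If the two columns shared a value, this dict merge would silently keep only one of the cells; in Spark the join would instead produce two columns with the same name.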

But I’m looking for a more compact solution.

Asked By: PaulH


Answers:

Using the link from @mrjoseph, I came up with the following solution.
It works and it's cleaner, but I still don't like the joins…

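def pivot_udf(df, *cols):
    # Start from the distinct ids so every id survives the left joins
    mydf = df.select('id').drop_duplicates()
    for c in cols:
        mydf = mydf.join(
            df
            # Prefix each value with its column name, e.g. 'col1_str1'
            .withColumn('combcol', sf.concat(sf.lit('{}_'.format(c)), df[c]))
            .groupby('id').pivot('combcol').agg(sf.count(c)),
            how='left',
            on='id'
        )
    return mydf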

pivot_udf(sdf, 'col1','col2').show()

+---+---------+---------+---------+---------+---------+
| id|col1_str1|col1_str2|col1_str3|col2_str4|col2_str5|
+---+---------+---------+---------+---------+---------+
|  1|        2|        2|        1|        3|        2|
|  2|     null|        2|        2|        3|        1|
+---+---------+---------+---------+---------+---------+
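The `'{}_'.format(c)` prefix matters once two source columns can share a value. This plain-Python sketch, using hypothetical rows in which `'x'` appears in both columns, lists the column names each per-column pivot would contribute to the joined result, with and without the prefix.

```python
# Hypothetical rows in which the value 'x' occurs in both col1 and col2
rows = [(1, 'x', 'x'), (2, 'x', 'y')]
cols = ['col1', 'col2']


def pivot_names(rows, cols, prefix):
    """Column names produced by pivoting each column separately and joining."""
    names = []
    for i, c in enumerate(cols, start=1):
        # Each per-column pivot yields one output column per distinct value
        values = {r[i] for r in rows}
        names += sorted(f'{c}_{v}' if prefix else v for v in values)
    return names


print(pivot_names(rows, cols, prefix=False))  # ['x', 'x', 'y']
print(pivot_names(rows, cols, prefix=True))   # ['col1_x', 'col2_x', 'col2_y']
```

Without the prefix, the joined DataFrame would carry two columns named `x`, which makes later references to that name ambiguous.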
Answered By: PaulH

Try this:

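from functools import reduce
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

cols = [x for x in sdf.columns if x != 'id']
# One two-column frame per source column: (id, '<column>_<value>')
df_array = [
    sdf.withColumn('col', F.concat(F.lit(x), F.lit('_'), F.col(x))).select('id', 'col')
    for x in cols
]

# Stack the frames, then pivot once on the combined column
reduce(DataFrame.union, df_array).groupby('id').pivot('col').agg(F.count('col')).show()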

Output:

+---+---------+---------+---------+---------+---------+
| id|col1_str1|col1_str2|col1_str3|col2_str4|col2_str5|
+---+---------+---------+---------+---------+---------+
|  1|        2|        2|        1|        3|        2|
|  2|     null|        2|        2|        3|        1|
+---+---------+---------+---------+---------+---------+
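The same data flow in plain Python, assuming the sample rows from the question: build one list of `(id, '<column>_<value>')` pairs per column, concatenate the lists with `functools.reduce` (standing in for `reduce(DataFrame.union, ...)`), and count. This only illustrates the reshaping; it is not Spark code.

```python
from collections import Counter
from functools import reduce
from operator import add

# Sample rows from the question: (id, col1, col2)
ROWS = [
    (1, 'str1', 'str4'), (1, 'str1', 'str4'), (1, 'str2', 'str4'),
    (1, 'str2', 'str5'), (1, 'str3', 'str5'), (2, 'str2', 'str4'),
    (2, 'str2', 'str4'), (2, 'str3', 'str4'), (2, 'str3', 'str5'),
]
COLS = ['col1', 'col2']

# One list of (id, '<column>_<value>') pairs per source column, mirroring
# sdf.withColumn('col', ...).select('id', 'col') for each column
per_column = [
    [(r[0], f'{c}_{r[i]}') for r in ROWS]
    for i, c in enumerate(COLS, start=1)
]
# reduce(add, ...) concatenates the lists, as reduce(DataFrame.union, ...) stacks frames
long_rows = reduce(add, per_column)
# Counting (id, key) pairs corresponds to groupby('id').pivot('col').agg(count)
counts = Counter(long_rows)

print(len(long_rows))            # 18
print(counts[(2, 'col1_str1')])  # 0 (Spark shows null for this cell)
```

Here the column-name prefix is essential: without it, a value present in both columns would be counted into a single cell.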
Answered By: Ala Tarighati

What you want here is not pivoting on multiple columns in the usual sense (that term normally means pivoting on combinations of several columns' values).
What you really want is to pivot on a single column, after first moving the values of both columns into that one column:

from pyspark.sql import functions as F

cols = [c for c in sdf.columns if c!= 'id']
sdf = (sdf
    .withColumn('_pivot', F.explode(F.array(
        *[F.concat(F.lit(f'{c}_'), c) for c in cols]
    ))).groupBy('id').pivot('_pivot').count().fillna(0)
)

sdf.show()
# +---+---------+---------+---------+---------+---------+
# | id|col1_str1|col1_str2|col1_str3|col2_str4|col2_str5|
# +---+---------+---------+---------+---------+---------+
# |  1|        2|        2|        1|        3|        2|
# |  2|        0|        2|        2|        3|        1|
# +---+---------+---------+---------+---------+---------+
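Conceptually, `F.explode(F.array(...))` turns each input row into one row per source column, so a single pivot afterwards is enough. A plain-Python sketch of that flow, assuming the sample rows, with `chain.from_iterable` playing the role of `explode` and `Counter`'s default of 0 playing the role of `.fillna(0)`:

```python
from collections import Counter
from itertools import chain

# Sample rows from the question: (id, col1, col2)
ROWS = [
    (1, 'str1', 'str4'), (1, 'str1', 'str4'), (1, 'str2', 'str4'),
    (1, 'str2', 'str5'), (1, 'str3', 'str5'), (2, 'str2', 'str4'),
    (2, 'str2', 'str4'), (2, 'str3', 'str4'), (2, 'str3', 'str5'),
]
COLS = ['col1', 'col2']

# Each input row becomes one (id, '<column>_<value>') record per column
exploded = chain.from_iterable(
    [(r[0], f'{c}_{v}') for c, v in zip(COLS, r[1:])] for r in ROWS
)
counts = Counter(exploded)

ids = sorted({i for i, _ in counts})
pivot_cols = sorted({k for _, k in counts})
# Counter returns 0 for missing keys, matching .fillna(0) after the pivot
table = {i: [counts[(i, k)] for k in pivot_cols] for i in ids}
print(pivot_cols)  # ['col1_str1', 'col1_str2', 'col1_str3', 'col2_str4', 'col2_str5']
print(table)       # {1: [2, 2, 1, 3, 2], 2: [0, 2, 2, 3, 1]}
```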
Answered By: ZygD