Pivot array of structs into columns using pyspark – without exploding the array

Question:

I currently have a dataframe with an id and a column which is an array of structs:

 root
 |-- id: string (nullable = true)
 |-- lists: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- _1: string (nullable = true)
 |    |    |-- _2: string (nullable = true)

Here is an example table with data:

 id | list1             | list2
 ------------------------------------------
 1  | [[a, av], [b, bv]]| [[e, ev], [f,fv]]
 2  | [[c, cv]]         | [[g,gv]]

How do I transform the above dataframe into the one below? I need to “explode” the arrays and add columns based on the first value in each struct.

 id | a   | b   | c   | d   | e   | f   | g
 -----------------------------------------------
 1  | av  | bv  | null| null| ev  | fv  | null
 2  | null| null| cv  | null| null| null| gv

A pyspark code to create the dataframe is as below:

d1 = spark.createDataFrame([("1", [("a", "av"), ("b", "bv")], [("e", "ev"), ("f", "fv")]),
                            ("2", [("c", "cv")], [("g", "gv")])], ["id", "list1", "list2"])

Note: I am on Spark version 2.2.0, so some SQL functions, such as map_concat, are not available.

Asked By: cody


Answers:

You can do this using higher-order functions, without exploding the arrays, like this:

import pyspark.sql.functions as f

# note: the SQL filter/transform higher-order functions used below require Spark 2.4+
d1.select('id',
          f.when(f.size(f.expr('''filter(list1,x->x._1='a')'''))>0,f.concat_ws(',',f.expr('''transform(filter(list1,x->x._1='a'),value->value._2)'''))).alias('a'),
          f.when(f.size(f.expr('''filter(list1,x->x._1='b')'''))>0,f.concat_ws(',',f.expr('''transform(filter(list1,x->x._1='b'),value->value._2)'''))).alias('b'),
          f.when(f.size(f.expr('''filter(list1,x->x._1='c')'''))>0,f.concat_ws(',',f.expr('''transform(filter(list1,x->x._1='c'),value->value._2)'''))).alias('c'),
          f.when(f.size(f.expr('''filter(list1,x->x._1='d')'''))>0,f.concat_ws(',',f.expr('''transform(filter(list1,x->x._1='d'),value->value._2)'''))).alias('d'),
          f.when(f.size(f.expr('''filter(list2,x->x._1='e')'''))>0,f.concat_ws(',',f.expr('''transform(filter(list2,x->x._1='e'),value->value._2)'''))).alias('e'),
          f.when(f.size(f.expr('''filter(list2,x->x._1='f')'''))>0,f.concat_ws(',',f.expr('''transform(filter(list2,x->x._1='f'),value->value._2)'''))).alias('f'),
          f.when(f.size(f.expr('''filter(list2,x->x._1='g')'''))>0,f.concat_ws(',',f.expr('''transform(filter(list2,x->x._1='g'),value->value._2)'''))).alias('g'),
          f.when(f.size(f.expr('''filter(list2,x->x._1='h')'''))>0,f.concat_ws(',',f.expr('''transform(filter(list2,x->x._1='h'),value->value._2)'''))).alias('h')
          ).show()


+---+----+----+----+----+----+----+----+----+
| id|   a|   b|   c|   d|   e|   f|   g|   h|
+---+----+----+----+----+----+----+----+----+
|  1|  av|  bv|null|null|  ev|  fv|null|null|
|  2|null|null|  cv|null|null|null|  gv|null|
+---+----+----+----+----+----+----+----+----+
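
The eight nearly identical expressions can also be built in a loop. A minimal sketch (assuming the same hard-coded key lists and Spark 2.4+, where the SQL filter/transform higher-order functions are available):

import pyspark.sql.functions as f

# which keys to extract from which array column (assumed fixed up front)
keys_per_list = {'list1': ['a', 'b', 'c', 'd'], 'list2': ['e', 'f', 'g', 'h']}

def key_column(list_col, key):
    # keep only the structs whose first field matches the key ...
    filtered = f.expr("filter({0}, x -> x._1 = '{1}')".format(list_col, key))
    # ... and pull out their second field
    values = f.expr("transform(filter({0}, x -> x._1 = '{1}'), v -> v._2)".format(list_col, key))
    return f.when(f.size(filtered) > 0, f.concat_ws(',', values)).alias(key)

cols = [key_column(lc, k) for lc, ks in keys_per_list.items() for k in ks]
d1.select('id', *cols).show()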

Hope it helps

Answered By: Shubham Jain

UPD – For Spark 2.2.0

You can define similar functions in 2.2.0 using udfs. They will be much less efficient in terms of performance and you’ll need a special function for each output value type (i.e. you won’t be able to have one element_at function which could output a value of any type from any map type), but they will work. The code below works for Spark 2.2.0:

import pyspark.sql.functions as f
from pyspark.sql.functions import udf
from pyspark.sql.types import MapType, ArrayType, StringType

@udf(MapType(StringType(), StringType()))
def map_from_entries(l):
    return {x:y for x,y in l}

@udf(MapType(StringType(), StringType()))
def map_concat(m1, m2):
    m1.update(m2)
    return m1

@udf(ArrayType(StringType()))
def map_keys(m):
    return list(m.keys())

def element_getter(k):
    @udf(StringType())
    def element_at(m):
        return m.get(k)
    return element_at

d2 = d1.select('id',
               map_concat(map_from_entries('list1'),
                          map_from_entries('list2')).alias('merged_map'))
map_keys = d2.select(f.explode(map_keys('merged_map')).alias('mk')) \
             .agg(f.collect_set('mk').alias('keys')) \
             .collect()[0].keys
# map_keys = ['a', 'b', 'c', 'd', 'e', 'f', 'g']  # alternatively, hard-code the column list
selects = [element_getter(k)('merged_map').alias(k) for k in sorted(map_keys)]
d = d2.select('id', *selects) 
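
Since the UDF’s return type is fixed when it is declared, a separate getter is needed for every value type. A minimal sketch of what a hypothetical integer-valued variant might look like (assuming a map whose values are integers):

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def int_element_getter(k):
    # hypothetical variant of element_getter for maps with integer values;
    # the only difference is the declared return type of the UDF
    @udf(IntegerType())
    def element_at(m):
        return m.get(k)
    return element_at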

ORIGINAL ANSWER (working for Spark 2.4.0+)

It is not clear where the d column came from in your example (d never appears in the initial dataframe). If columns should be created based on the first elements in the arrays, then this should work (assuming the total number of unique first values in the lists is small enough):

import pyspark.sql.functions as f
d2 = d1.select('id',
               f.map_concat(f.map_from_entries('list1'),
                            f.map_from_entries('list2')).alias('merged_map'))
map_keys = d2.select(f.explode(f.map_keys('merged_map')).alias('mk')) \
             .agg(f.collect_set('mk').alias('keys')) \
             .collect()[0].keys
selects = [f.element_at('merged_map', k).alias(k) for k in sorted(map_keys)]
d = d2.select('id', *selects)

Output (no column for d because it is never mentioned in the initial DataFrame):

+---+----+----+----+----+----+----+
| id|   a|   b|   c|   e|   f|   g|
+---+----+----+----+----+----+----+
|  1|  av|  bv|null|  ev|  fv|null|
|  2|null|null|  cv|null|null|  gv|
+---+----+----+----+----+----+----+

If you actually had in mind that the list of columns is fixed from the beginning (and the columns are not taken from the array), then you can just replace the definition of the variable map_keys with the fixed list of columns, e.g. map_keys = ['a', 'b', 'c', 'd', 'e', 'f', 'g']; a short sketch of this variant follows the table below. In that case you get the output you mention in the question:

+---+----+----+----+----+----+----+----+
| id|   a|   b|   c|   d|   e|   f|   g|
+---+----+----+----+----+----+----+----+
|  1|  av|  bv|null|null|  ev|  fv|null|
|  2|null|null|  cv|null|null|null|  gv|
+---+----+----+----+----+----+----+----+
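
A minimal sketch of that fixed-list variant (same pipeline as above, only map_keys is hard-coded; the column names are taken from the question, including d, which never occurs in the data):

import pyspark.sql.functions as f

d2 = d1.select('id',
               f.map_concat(f.map_from_entries('list1'),
                            f.map_from_entries('list2')).alias('merged_map'))

# fixed list of output columns instead of collecting the keys from the data
map_keys = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
selects = [f.element_at('merged_map', k).alias(k) for k in sorted(map_keys)]
d = d2.select('id', *selects)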

By the way, what you want to do is not what is called explode in Spark. explode in Spark is for the situation when you create multiple rows from one row. E.g. if you wanted to go from a dataframe like this:

+---+---------+
| id|      arr|
+---+---------+
|  1|   [a, b]|
|  2|[c, d, e]|
+---+---------+

to this:

+---+-------+
| id|element|
+---+-------+
|  1|      a|
|  1|      b|
|  2|      c|
|  2|      d|
|  2|      e|
+---+-------+
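
A minimal sketch of that explode call (the df name and sample data are just an illustration of the two tables above):

import pyspark.sql.functions as f

df = spark.createDataFrame([("1", ["a", "b"]), ("2", ["c", "d", "e"])], ["id", "arr"])

# explode creates one output row per array element
df.select('id', f.explode('arr').alias('element')).show()
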
Answered By: Alexander Pivovarov

from pyspark.sql.functions import explode, first, concat

d1 = spark.createDataFrame([("1", [("a", "av"), ("b", "bv")], [("e", "ev"), ("f", "fv")]),
                            ("2", [("c", "cv")], [("g", "gv")])], ["id", "list1", "list2"])

# concatenate both arrays into one (concat on array columns requires Spark 2.4+),
# explode into one row per struct, then pivot on the struct's first field,
# keeping its second field as the cell value
d2 = d1.withColumn('concat', concat('list1', 'list2'))
d3 = d2.withColumn('explode', explode('concat'))
d4 = d3.groupby('id').pivot('explode._1').agg(first('explode._2'))

d4.show()
+---+----+----+----+----+----+----+
|id |a   |b   |c   |e   |f   |g   |
+---+----+----+----+----+----+----+
|1  |av  |bv  |null|ev  |fv  |null|
|2  |null|null|cv  |null|null|gv  |
+---+----+----+----+----+----+----+
Answered By: eugene.lebed
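
If the full column list is known in advance (including d, which never occurs in the data), one possible variation is to pass the values explicitly to pivot; this also avoids the extra job Spark runs to determine the distinct pivot values (the column list here is an assumption taken from the question):

d4 = d3.groupby('id').pivot('explode._1', ['a', 'b', 'c', 'd', 'e', 'f', 'g']) \
       .agg(first('explode._2'))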