How to sort a DataFrame's nested array column in PySpark by a specific inner element


I have a DataFrame where I use groupBy on the key and collect_list to create an array of structs from col1 and col2. I want to sort the structs inside the collected list by the second element (col2) after the list is formed.
I am not sure whether sorting the DataFrame by col2 first and then calling collect_list preserves the sort order (I have found both yes and no answers for Spark), so I prefer to sort after the list has been created, since my next logic depends on that order. I tried a UDF, which sometimes works but sometimes throws an error.

import pyspark.sql.functions as F
from pyspark.sql.functions import collect_list, collect_set, expr, struct
import operator
from operator import itemgetter

def sorter(key_value_list):
  res= sorted(key_value_list, key=lambda x:x[1], reverse=True)
  return [ [item[0], item[1]] for item in res]

Instead of this return statement (return [[item[0], item[1]] for item in res]), I also tried the ones below, but neither worked. Only the statement above works, and only sometimes: on the bulk data it throws an error.

return [concat_ws('|', item[0], item[1]) for item in res]
return [array(item[0], item[1]) for item in res]

sort_udf = F.udf(sorter)
df1 = df.groupBy("group_key").agg(F.collect_list(F.struct("col1", "col2")).alias("key_value"))
df1.withColumn("sorted_key_value", sort_udf("key_value")).show(truncate=False)


Sample input:

group_key col1 col2
123       a    5
123       a    6
123       b    6
123       cd   3 
123       d    2
123       ab   9
456       d    4  
456       ad   6 
456       ce   7 
456       a    4 
456       s    3 

Normal output without sorting:

group_key key_value_arr
123 [[a, 5], [a, 6], [b, 6], [cd, 3], [d, 2], [ab, 9]]
456 [[d, 4], [ad, 6], [ce, 7], [a, 4], [s, 3]]

Intended output. When I do get this output, the UDF returns a string; I want an array of strings.

group_key key_value_arr
123 [[ab, 9], [a, 6], [b, 6], [a, 5], [cd, 3], [d, 2]]
456 [[ce, 7], [ad, 6], [d, 4], [a, 4], [s, 3]]
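A UDF created with F.udf and no returnType defaults to StringType, which is likely why the sorted result comes back as a single string. The sorting itself can be checked in plain Python, outside Spark. The sketch below assumes each struct arrives as a tuple-like row (a pyspark Row is a tuple subclass) and returns lists of strings, which would fit a return type declared as ArrayType(ArrayType(StringType())); sort_pairs is a hypothetical name.

```python
def sort_pairs(pairs):
    """Sort (col1, col2) pairs by col2, descending.

    Python's sort is stable, so pairs with equal col2 keep their input order.
    Values are returned as lists of strings to fit an
    ArrayType(ArrayType(StringType())) return type.
    """
    ordered = sorted(pairs, key=lambda p: p[1], reverse=True)
    return [[str(col1), str(col2)] for col1, col2 in ordered]


# Group 123 from the sample input, in collect_list order.
group_123 = [("a", 5), ("a", 6), ("b", 6), ("cd", 3), ("d", 2), ("ab", 9)]
print(sort_pairs(group_123))
# [['ab', '9'], ['a', '6'], ['b', '6'], ['a', '5'], ['cd', '3'], ['d', '2']]
```

Because the sort is stable, [a, 6] stays ahead of [b, 6] here, matching the intended output for group 123.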

Error on bulk data:

  File "/hadoop/6/yarn/local/usercache/b_incdata_rw/appcache/application_1660704390900_2796904/container_e3797_1660704390900_2796904_01_000002/", line 328, in get_return_value
    format(target_id, ".", name), value)
Py4JJavaError: An error occurred while calling o184.showString.

Another way I tried, returning an array of strings (col1 and col2 delimited by |):

def sorter(key_value_list):
    l = []
    s = ""
    res = sorted(key_value_list, key=lambda x:x[1], reverse=True)
    for item in res:
        s = F.concat_ws('|', item[0], item[1])
    return l

from pyspark.sql.types import ArrayType, StringType

sort_udf = F.udf(sorter, ArrayType(StringType()))

df6 = df4.withColumn("sorted_key_value", sort_udf("key_value"))

I also tried returning res directly as a list, and that gives the same error:

 File "/hadoop/10/yarn/local/usercache/b_incdata_rw/appcache/application_1660704390900_2799810/container_e3797_1660704390900_2799810_01_000453/", line 1398, in concat_ws
    return Column(sc._jvm.functions.concat_ws(sep, _to_seq(sc, cols, _to_java_column)))
AttributeError: 'NoneType' object has no attribute '_jvm'
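The AttributeError points at F.concat_ws being called inside the UDF: functions in pyspark.sql.functions build Column expressions through the driver's SparkContext (the sc._jvm in the traceback), and inside a UDF running on an executor there is no active context, so sc is None. Plain Python string operations are the usual substitute inside a UDF. Note also that the loop above never appends anything to l, so it would return an empty list even without the error. A minimal plain-Python sketch that could serve as the UDF body with ArrayType(StringType()); to_delimited is a hypothetical name.

```python
def to_delimited(pairs, sep="|"):
    """Sort (col1, col2) pairs by col2 descending and join each as 'col1|col2'.

    Uses only plain Python string formatting, so it does not need a SparkContext.
    """
    ordered = sorted(pairs, key=lambda p: p[1], reverse=True)
    result = []
    for col1, col2 in ordered:
        result.append(f"{col1}{sep}{col2}")  # build the string, then collect it
    return result


group_456 = [("d", 4), ("ad", 6), ("ce", 7), ("a", 4), ("s", 3)]
print(to_delimited(group_456))
# ['ce|7', 'ad|6', 'd|4', 'a|4', 's|3']
```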

Third approach I tried:

def def_sort(x):
    # Compare col2 as a number; as text, "10" would sort below "9".
    return sorted(x, key=lambda s: float(s.split('|')[1]), reverse=True)

udf_sort = F.udf(def_sort, ArrayType(StringType()))

df_full_multi2.withColumn("sorted_list", F.array_distinct(udf_sort("key_value"))).show(100, truncate=False)

I get the intended result:

group_key sorted_list
123 [ab|9, a|6, b|6, a|5, cd|3, d|2]
456 [ce|7, ad|6, d|4, a|4, s|3]
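Sorting on the raw col2 substring compares text, not numbers, so it only gives the intended order while every col2 value has the same number of digits, as in the sample data. A quick plain-Python check (the values "a|9", "b|10", "c|2" are made up for illustration):

```python
from decimal import Decimal

items = ["a|9", "b|10", "c|2"]

# Comparing the col2 substring as text: "9" > "2" > "10".
as_text = sorted(items, key=lambda s: s.split("|")[1], reverse=True)

# Comparing it as a number gives the intended order.
as_number = sorted(items, key=lambda s: Decimal(s.split("|")[1]), reverse=True)

print(as_text)    # ['a|9', 'c|2', 'b|10']
print(as_number)  # ['b|10', 'a|9', 'c|2']
```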

However, when I write the result to Parquet, I get this error:

An error occurred while calling o178.parquet.Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Task 807 in stage 2.0 failed 4 times, most recent failure: Lost task 807.3 in stage 2.0 (TID 3495,, executor 1006): ExecutorLostFailure (executor 1006 exited caused by one of the running tasks) Reason: Container killed by YARN for exceeding memory limits. 21.1 GB of 21 GB physical memory used. Consider boosting spark.yarn.executor.memoryOverhead.
Driver stacktrace:
  at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:196)
  File "/hadoop/10/yarn/local/usercache/b_incdata_rw/appcache/application_1663728370843_1731/container_e3798_1663728370843_1731_01_000001/", line 328, in get_return_value
    format(target_id, ".", name), value)
Py4JJavaError: An error occurred while calling o178.parquet.
Asked By: user14297339



  • Create the struct the opposite way: first "col2", then "col1".
  • Sort the array in descending order using sort_array(..., False).
  • Swap the fields back inside the struct using transform.


from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(123, 'a', 5),
     (123, 'a', 6),
     (123, 'b', 6),
     (123, 'cd', 3 ),
     (123, 'd', 2),
     (123, 'ab', 9),
     (456, 'd', 4  ),
     (456, 'ad', 6 ),
     (456, 'ce', 7 ),
     (456, 'a', 4 ),
     (456, 's', 3 )],
    ['group_key', 'col1', 'col2'])


df1 = (df
    .groupBy("group_key")
    .agg(F.sort_array(F.collect_list(F.struct("col2", "col1")), False).alias("key_value"))
)
df2 = df1.withColumn("key_value", F.expr("transform(key_value, x -> struct(x.col1, x.col2))"))
# +---------+--------------------------------------------------+
# |group_key|key_value                                         |
# +---------+--------------------------------------------------+
# |123      |[{ab, 9}, {b, 6}, {a, 6}, {a, 5}, {cd, 3}, {d, 2}]|
# |456      |[{ce, 7}, {ad, 6}, {d, 4}, {a, 4}, {s, 3}]        |
# +---------+--------------------------------------------------+
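The trick relies on structs being compared field by field, first col2, then col1. Python tuples compare the same way, so the three steps can be checked in plain Python on group 123 (an illustration of the idea, not Spark code):

```python
# (col1, col2) rows for group 123, in collect_list order.
rows = [("a", 5), ("a", 6), ("b", 6), ("cd", 3), ("d", 2), ("ab", 9)]

flipped = [(col2, col1) for col1, col2 in rows]    # step 1: put col2 first
ordered = sorted(flipped, reverse=True)            # step 2: sort descending, field by field
result = [(col1, col2) for col2, col1 in ordered]  # step 3: swap the fields back

print(result)
# [('ab', 9), ('b', 6), ('a', 6), ('a', 5), ('cd', 3), ('d', 2)]
```

Because the whole tuple is sorted descending, ties on col2 end up ordered by col1 descending ("b" before "a"), which matches the {b, 6}, {a, 6} order in the output above.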

If you need more advanced sorting, you can write a comparator function. Refer to this question for examples of sorting arrays of structs.
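A comparator takes two elements and returns a negative number if the first should come first, a positive number if the second should, and 0 on a tie. In Spark 3.0 and later, array_sort can take such a comparator as a SQL lambda; the sketch below only illustrates the same contract in plain Python via functools.cmp_to_key, with an example ordering (col2 descending, then col1 ascending) chosen for illustration. compare_pairs is a hypothetical name.

```python
from functools import cmp_to_key


def compare_pairs(left, right):
    """Return <0 if left sorts first, >0 if right sorts first, 0 on a tie.

    Orders (col1, col2) pairs by col2 descending, then col1 ascending.
    """
    (l_key, l_val), (r_key, r_val) = left, right
    if l_val != r_val:
        return -1 if l_val > r_val else 1
    if l_key != r_key:
        return -1 if l_key < r_key else 1
    return 0


pairs = [("b", 6), ("ab", 9), ("a", 6), ("a", 5)]
print(sorted(pairs, key=cmp_to_key(compare_pairs)))
# [('ab', 9), ('a', 6), ('b', 6), ('a', 5)]
```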

In case you want an array of strings, use this:

df1 = (df
    .groupBy("group_key")
    .agg(F.sort_array(F.collect_list(F.struct("col2", "col1")), False).alias("key_value"))
)
df2 = df1.withColumn("key_value", F.expr("transform(key_value, x -> concat(x.col1, '|', x.col2))"))
# +---------+--------------------------------+
# |group_key|key_value                       |
# +---------+--------------------------------+
# |123      |[ab|9, b|6, a|6, a|5, cd|3, d|2]|
# |456      |[ce|7, ad|6, d|4, a|4, s|3]     |
# +---------+--------------------------------+
Answered By: ZygD

I see another issue now. My next logic takes this key_value list and traverses it in sort order. The first col1 I encounter, I store in another list, "t"; if it has two tokens, like 'a' and 'b' in 'ab', I store them in "t" as well. If I find 'a', 'b' or 'ab' in any later col1, I discard it. If multiple tokens have the same col2 value, I need to sort those tokens lexicographically so that the logic produces the same result every time, since I use the col2 values in the t_val list for the final calculation.

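from decimal import Decimal
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def func(kv):
    for x in kv:
        key = str(x.split('|')[0])
        value = Decimal(str(x.split('|')[1]))
        if (key in t) or ((len(key) == 2 and key[0] in t) or (len(key) == 2 and key[1] in t)):
            ...  # (branch body not shown in the original post)
        elif len(key) == 2:
            ...  # (branch body not shown in the original post)
        elif len(key) == 1:
            ...  # (branch body not shown in the original post)
udf_c = udf(func, StringType())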
r ='key_value').alias('fnl_c'))
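The branch bodies above are not shown, so the following is only one hypothetical reading of the description, not the poster's function: every key that does not clash with something already in t is accepted, a two-character key also records its two characters, and the col2 value of each accepted key goes into t_val. The name select_tokens and the (t, t_val) return value are assumptions, as is the restriction to one- or two-character keys. It shows why the order of ties matters: the first key seen among equal col2 values is the one that gets accepted.

```python
from decimal import Decimal


def select_tokens(kv):
    """Hypothetical sketch of the traversal described above.

    kv: "col1|col2" strings, already in the desired order.
    Returns (t, t_val): accepted keys/tokens and the col2 values of accepted keys.
    Assumes every col1 has one or two characters.
    """
    t, t_val = [], []
    for x in kv:
        key, raw_value = x.split("|")
        value = Decimal(raw_value)
        # Discard the key if it, or either of its characters, was already seen.
        if key in t or (len(key) == 2 and (key[0] in t or key[1] in t)):
            continue
        if len(key) == 2:
            t.extend([key, key[0], key[1]])  # e.g. 'ab' also records 'a' and 'b'
            t_val.append(value)
        elif len(key) == 1:
            t.append(key)
            t_val.append(value)
    return t, t_val


t, t_val = select_tokens(["ab|9", "ac|9", "ad|9", "a|5", "cd|3", "d|2", "e|2"])
print(t)      # ['ab', 'a', 'b', 'cd', 'c', 'd', 'e']
print(t_val)  # [Decimal('9'), Decimal('3'), Decimal('2')]
```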

Current output:
    group_key   key_value
    123 [ab|9, ad|9, ac|9, a|5, cd|3, e|2, d|2]
    456 [ce|7, ad|7, d|4, a|4, s|3, k|3]

Desired output (ties on col2 sorted lexicographically by col1):
    group_key   key_value
    123 [ab|9, ac|9, ad|9, a|5, cd|3, d|2, e|2]
    456 [ad|7, ce|7, a|4, d|4, k|3, s|3]
Answered By: user14297339