PySpark – array column: combine all rows that share at least one value

Question:

I have a Spark dataframe df that looks like this:

+------------+
|      values|
+------------+
|      [a, b]|
|[a, b, c, d]|
|   [a, e, f]|
|   [w, x, y]|
|      [x, z]|
+------------+
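
For reproducibility, a dataframe like this can be built as follows (a minimal sketch):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(['a', 'b'],), (['a', 'b', 'c', 'd'],), (['a', 'e', 'f'],),
     (['w', 'x', 'y'],), (['x', 'z'],)],
    ['values']
)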

And I want to be able to get another dataframe that looks like this:

+-------------------+
|             values|
+-------------------+
| [a, b, c, d, e, f]|
|       [w, x, y, z]|
+-------------------+

In other words, I want to merge all rows whose arrays share at least one common value, chaining the merges transitively until no two remaining rows overlap.

I’m aware that this thread exists:
Spark get all rows with same values in array in column
but it doesn’t give the answer I was looking for.

I also saw this one:
Pyspark merge dataframe rows one array is contained in another
So I tried copying the code from the accepted answer, but unfortunately I’m still not getting my desired output:

from pyspark.sql.functions import expr

df_sub = df.alias('d1').join(
    df.alias('d2'),
    expr('size(array_except(d2.values, d1.values)) == 0 AND size(d2.values) < size(d1.values)')
).select('d2.values').distinct()

df.join(df_sub, on=['values'], how='left_anti') \
    .withColumn('values', expr('sort_array(values)')) \
    .distinct() \
    .show()

Output:

+------------+
|      values|
+------------+
|   [a, e, f]|
|   [w, x, y]|
|[a, b, c, d]|
|      [x, z]|
+------------+

This is probably because the linked answer only removes arrays that are subsets of a longer array; it never merges arrays that merely overlap. How can I solve this?

Asked By: ninjaman


Answers:

Given an input dataframe (say, data_sdf) like the following:

# +------------+---+
# |vals        |id |
# +------------+---+
# |[a, b]      |1  |
# |[a, b, c, d]|2  |
# |[a, e, f]   |3  |
# |[k, l, m]   |4  |
# |[w, x, y]   |5  |
# |[x, z]      |6  |
# +------------+---+

Notice the id field that I added; it captures the data’s sort order. I also added a new row (id = 4) that should not get merged with any other row.
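
For reference, a minimal sketch of how this input could be constructed (mirroring the table above):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data_sdf = spark.createDataFrame(
    [(['a', 'b'], 1), (['a', 'b', 'c', 'd'], 2), (['a', 'e', 'f'], 3),
     (['k', 'l', 'm'], 4), (['w', 'x', 'y'], 5), (['x', 'z'], 6)],
    ['vals', 'id']
)

The merging can then be done with window functions: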

from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

(
    data_sdf
    # pull in the next row's array (in id order)
    .withColumn('lead_vals', func.lead('vals').over(wd.orderBy('id')))
    # 1 when the current array does NOT overlap the next one, else 0
    .withColumn('vals_nonoverlap_flg',
                func.abs(func.arrays_overlap('vals', 'lead_vals').cast('int') - 1))
    # running sum of the flag: a group id that increments at every break
    .withColumn('blah', func.sum('vals_nonoverlap_flg').over(wd.orderBy('id')))
    # the first row contributes its own array unioned with its lead;
    # every other row contributes only its lead
    .withColumn('fnl_val_to_merge',
                func.when(func.row_number().over(wd.orderBy('id')) == 1,
                          func.array_union('vals', 'lead_vals'))
                .otherwise(func.col('lead_vals')))
    # merge all arrays that landed in the same group
    .groupBy('blah')
    .agg(func.array_distinct(func.flatten(func.collect_list('fnl_val_to_merge'))).alias('merged_val'))
    .drop('blah')
    .show(truncate=False)
)

# +------------------+
# |merged_val        |
# +------------------+
# |[a, b, c, d, e, f]|
# |[k, l, m]         |
# |[w, x, y, z]      |
# +------------------+
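
To see how the grouping works, here are the intermediate columns for the sample data (traced by hand from the logic above, not pasted from a run):

# +---+------------+------------+-------------------+----+----------------+
# |id |vals        |lead_vals   |vals_nonoverlap_flg|blah|fnl_val_to_merge|
# +---+------------+------------+-------------------+----+----------------+
# |1  |[a, b]      |[a, b, c, d]|0                  |0   |[a, b, c, d]    |
# |2  |[a, b, c, d]|[a, e, f]   |0                  |0   |[a, e, f]       |
# |3  |[a, e, f]   |[k, l, m]   |1                  |1   |[k, l, m]       |
# |4  |[k, l, m]   |[w, x, y]   |1                  |2   |[w, x, y]       |
# |5  |[w, x, y]   |[x, z]      |0                  |2   |[x, z]          |
# |6  |[x, z]      |null        |null               |2   |null            |
# +---+------------+------------+-------------------+----+----------------+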

P.S. In real use, add a partitionBy() to every window used above; without one, Spark moves all the data to a single partition to compute the window.
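
For example, assuming a hypothetical grouping column grp exists in your data:

from pyspark.sql.window import Window as wd

# 'grp' is a hypothetical partitioning column, not present in the sample data
w = wd.partitionBy('grp').orderBy('id')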


You could also use the aggregate higher-order function: collect all the arrays into a single row, then fold over them, merging each array into the most recent group whenever they overlap.

(
    data_sdf
    # collapse all rows into a single array of arrays
    .groupBy(func.lit('gk').alias('gk'))
    .agg(func.collect_list('vals').alias('vals_arr'))
    # fold over the arrays: merge each one into the last group if they
    # overlap, otherwise append it as a new group
    .withColumn('blah',
                func.expr('''
                          aggregate(slice(vals_arr, 2, size(vals_arr)),
                                    array(vals_arr[0]),
                                    (x, y) -> if(arrays_overlap(x[size(x)-1], y),
                                                 array_union(slice(x, 1, size(x)-1), array(array_union(x[size(x)-1], y))),
                                                 array_union(x, array(y))
                                                 )
                                    )
                          '''))
    .selectExpr('explode(blah) as merged_vals')
    .show(truncate=False)
)

# +------------------+
# |merged_vals       |
# +------------------+
# |[a, b, c, d, e, f]|
# |[k, l, m]         |
# |[w, x, y, z]      |
# +------------------+
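
Note that both approaches are order-sensitive: the window version only merges rows that sit next to each other in id order, and the fold in aggregate only compares each incoming array against the most recently formed group. An input where overlapping rows appear far apart (for example [a, b], [x, z], [a, c]) would not produce the desired grouping with either approach as written.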
Answered By: samkart