How to coalesce multiple pyspark arrays?

Question:

I have an arbitrary number of arrays of equal length in a PySpark DataFrame. I need to coalesce these, element by element, into a single list. The problem with coalesce is that it doesn’t work by element, but rather selects the entire first non-null array. Any suggestions for how to accomplish this would be appreciated. Please see the test case below for an example of expected input and output:

def test_coalesce_elements():
    """
    Test array coalescing on a per-element basis
    """
    from pyspark.sql import SparkSession
    import pyspark.sql.types as t
    import pyspark.sql.functions as f

    spark = SparkSession.builder.getOrCreate()

    data = [
        {
            "a": [None, 1, None, None],
            "b": [2, 3, None, None],
            "c": [5, 6, 7, None],
        }
    ]

    schema = t.StructType([
        t.StructField('a', t.ArrayType(t.IntegerType())),
        t.StructField('b', t.ArrayType(t.IntegerType())),
        t.StructField('c', t.ArrayType(t.IntegerType())),
    ])
    df = spark.createDataFrame(data, schema)

    # Inspect schema
    df.printSchema()
    # root
    #  |-- a: array (nullable = true)
    #  |    |-- element: integer (containsNull = true)
    #  |-- b: array (nullable = true)
    #  |    |-- element: integer (containsNull = true)
    #  |-- c: array (nullable = true)
    #  |    |-- element: integer (containsNull = true)

    # Inspect df values
    df.show(truncate=False)
    # +---------------------+------------------+---------------+
    # |a                    |b                 |c              |
    # +---------------------+------------------+---------------+
    # |[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|
    # +---------------------+------------------+---------------+

    # This obviously does not work, but hopefully provides the general idea
    # Remember: this will need to work with an arbitrary and dynamic set of columns
    input_cols = ['a', 'b', 'c']
    df = df.withColumn('d', f.coalesce(*[f.col(i) for i in input_cols]))
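
    # For reference, plain coalesce picks the whole first non-null array here:
    # column "a" itself is not null (only some of its elements are), so "d" would
    # simply become a copy of "a". Something like:
    df.show(truncate=False)
    # +---------------------+------------------+---------------+---------------------+
    # |a                    |b                 |c              |d                    |
    # +---------------------+------------------+---------------+---------------------+
    # |[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|[null, 1, null, null]|
    # +---------------------+------------------+---------------+---------------------+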

    # This is the expected output I would like to see for the given inputs
    assert df.collect()[0]['d'] == [2, 1, 7, None]

Thanks in advance for any ideas!

Asked By: Brendan

Answers:

Although it would be ideal to avoid one, I am not sure there is an elegant way to do this using only built-in PySpark functions.

What I did was write a UDF that takes in a variable number of columns (using *args) and returns an array of integers.

@f.udf(returnType=t.ArrayType(t.IntegerType()))
def get_array_non_null_first_element(*args):
    data_array = list(args)
    array_lengths = [len(array) for array in data_array]

    ## check that all of the arrays have the same length
    assert len(set(array_lengths)) == 1

    ## if they do, then you can set the array length
    array_length = array_lengths[0]

    first_value_array = []
    for i in range(array_length):
        ## collect the i-th element from every array and keep the first non-null one
        element_array = [array[i] for array in data_array]
        value = None
        for x in element_array:
            if x is not None:
                value = x
                break
        first_value_array.append(value)
    return first_value_array

Then create a new column d by applying this UDF to whichever columns you like:

df.withColumn("d", get_array_non_null_first_element(F.col('a'), F.col('b'), F.col('c'))).show()

+--------------------+------------------+---------------+---------------+
|                   a|                 b|              c|              d|
+--------------------+------------------+---------------+---------------+
|[null, 1, null, n...|[2, 3, null, null]|[5, 6, 7, null]|[2, 1, 7, null]|
+--------------------+------------------+---------------+---------------+
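
Since the question calls for an arbitrary, dynamic set of columns, the same UDF can also be driven from a list of column names rather than hard-coded ones. A small sketch building on the UDF above (input_cols is just an assumed list of the array column names):

input_cols = ['a', 'b', 'c']  # any dynamic list of equally sized array columns
df.withColumn("d", get_array_non_null_first_element(*[f.col(c) for c in input_cols])).show()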
Answered By: Derek O

As Derek and the OP have noted, Derek’s answer works, but it would be better to avoid UDFs. Here is a way to accomplish it natively:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

# Order by any static value; we only need a row number for every row in the DataFrame
w = Window().orderBy(F.lit('A'))

# Will be used later to join df with the second df containing the calculated "d" column
df = df.withColumn("row_num", F.row_number().over(w))

print("DF:")
df.show(truncate=False)

# Input Columns
input_cols = ['a', 'b', 'c']

# Zip all the arrays using arrays_zip
# Explode the zipped array
# Create new columns from the exploded zipped array to get single values
# Coalesce to get the first non-null value
# Group by row_num as we want to bring all the values back into one array
# Wrap "d" in an array before collect_list (collect_list ignores nulls), then flatten the nested arrays into one flat array
df_2 = df.withColumn("new", F.arrays_zip(*input_cols)) \
    .withColumn("new", F.explode("new")) \
    .select("row_num", *[F.col(f"new.{i}").alias(f"new_{i}") for i in input_cols]) \
    .withColumn("d", F.coalesce(*[F.col(f"new_{i}") for i in input_cols])) \
    .groupBy("row_num") \
    .agg(F.flatten(F.collect_list(F.array("d"))).alias("d"))

print("Second DF:")
df_2.show(truncate=False)

# Join based on the row_num
final_df = df.join(df_2, df["row_num"] == df_2["row_num"], "inner") \
    .drop("row_num")

# voilà
print("Final DF:")
final_df.show(truncate=False)

assert final_df.collect()[0]["d"] == [2, 1, 7, None]
DF:
+---------------------+------------------+---------------+-------+
|a                    |b                 |c              |row_num|
+---------------------+------------------+---------------+-------+
|[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|1      |
+---------------------+------------------+---------------+-------+

Second DF:
+-------+---------------+
|row_num|d              |
+-------+---------------+
|1      |[2, 1, 7, null]|
+-------+---------------+

Final DF:
+---------------------+------------------+---------------+---------------+
|a                    |b                 |c              |d              |
+---------------------+------------------+---------------+---------------+
|[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|[2, 1, 7, null]|
+---------------------+------------------+---------------+---------------+
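
For intuition, the zip-and-explode stage turns each row of arrays into one row per element before the per-element coalesce runs. A sketch of that intermediate result (illustration only; exact show formatting and null rendering may vary by Spark version):

df.withColumn("new", F.arrays_zip(*input_cols)) \
    .withColumn("new", F.explode("new")) \
    .select("row_num", *[F.col(f"new.{i}").alias(f"new_{i}") for i in input_cols]) \
    .show()
# +-------+-----+-----+-----+
# |row_num|new_a|new_b|new_c|
# +-------+-----+-----+-----+
# |      1| null|    2|    5|
# |      1|    1|    3|    6|
# |      1| null| null|    7|
# |      1| null| null| null|
# +-------+-----+-----+-----+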
Answered By: Tushar Patil

Thanks to Derek and Tushar for their responses! I was able to use them as a basis to solve the issue without a UDF, join, or explode.

Generally speaking, joins are unfavorable because they are computationally and memory expensive, UDFs can be computationally intensive, and explode can be memory intensive. Fortunately, we can avoid all of these using transform:

from typing import List, Union

import pyspark.sql.functions as f
from pyspark.sql import Column, DataFrame


def add_coalesced_list_by_elements_col(
    df: DataFrame,
    cols: List[Union[Column, str]],
    col_name: str,
) -> DataFrame:
    """
    Adds a new column representing a list that is collected element by element
    from the input columns. Please note that this does not check that all
    provided columns are of equal length.

    Args:
        df: Input DataFrame to add the column to
        cols: List of columns to collect by element. All columns should be of equal length.
        col_name: The name of the new column

    Returns:
        DataFrame with the result added as a new column.
    """
    return (
        df
        .withColumn(
            col_name,
            f.transform(
                # It doesn't matter which column we iterate over; we only use the index, not the value
                cols[0],
                # We use i + 1 because SQL array indexing starts at 1, while the transform index starts at 0
                lambda _, i: f.coalesce(*[f.element_at(c, i + 1) for c in cols]))
        )
    )
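
For intuition, with the three columns used in the test below, the transform call expands to roughly the following Spark SQL expression (a hand-written equivalent, shown only as an illustration):

# Roughly what the helper builds for cols=["a", "b", "c"] (illustration only)
df.withColumn(
    "d",
    f.expr(
        "transform(a, (x, i) -> "
        "coalesce(element_at(a, i + 1), element_at(b, i + 1), element_at(c, i + 1)))"
    ),
)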


def test_collect_list_elements():
    import pyspark.sql.types as t
    from pyspark.sql import SparkSession

    # Arrange
    spark = SparkSession.builder.getOrCreate()

    data = [
        {
            "id": 1,
            "a": [None, 1, None, None],
            "b": [2, 3, None, None],
            "c": [5, 6, 7, None],
        }
    ]

    schema = t.StructType(
        [
            t.StructField("id", t.IntegerType()),
            t.StructField("a", t.ArrayType(t.IntegerType())),
            t.StructField("b", t.ArrayType(t.IntegerType())),
            t.StructField("c", t.ArrayType(t.IntegerType())),
        ]
    )
    df = spark.createDataFrame(data, schema)

    # Act
    df = add_coalesced_list_by_elements_col(df=df, cols=["a", "b", "c"], col_name="d")

    # Assert new col is correct output
    assert df.collect()[0]["d"] == [2, 1, 7, None]

    # Assert all the other cols are not affected
    assert df.collect()[0]["a"] == [None, 1, None, None]
    assert df.collect()[0]["b"] == [2, 3, None, None]
    assert df.collect()[0]["c"] == [5, 6, 7, None]
Answered By: Brendan