How to coalesce multiple pyspark arrays?
Question:
I have an arbitrary number of arrays of equal length in a PySpark DataFrame. I need to coalesce these, element by element, into a single list. The problem with coalesce is that it doesn’t work by element, but rather selects the entire first non-null array. Any suggestions for how to accomplish this would be appreciated. Please see the test case below for an example of expected input and output:
def test_coalesce_elements():
    """
    Test array coalescing on a per-element basis.

    The plain ``coalesce`` call below is only a sketch of the intent: it
    selects a whole array instead of working element by element, so the final
    assertion states the wanted output rather than what this code produces.
    """
    import pyspark.sql.functions as f
    import pyspark.sql.types as t
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    row = {
        "a": [None, 1, None, None],
        "b": [2, 3, None, None],
        "c": [5, 6, 7, None],
    }
    # Every column is a nullable array of nullable integers.
    schema = t.StructType(
        [t.StructField(name, t.ArrayType(t.IntegerType())) for name in ('a', 'b', 'c')]
    )
    df = spark.createDataFrame([row], schema)

    # Inspect schema
    df.printSchema()
    # root
    # | -- a: array(nullable=true)
    # | | -- element: integer(containsNull=true)
    # | -- b: array(nullable=true)
    # | | -- element: integer(containsNull=true)
    # | -- c: array(nullable=true)
    # | | -- element: integer(containsNull=true)

    # Inspect df values
    df.show(truncate=False)
    # +---------------------+------------------+---------------+
    # |a |b |c |
    # +---------------------+------------------+---------------+
    # |[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|
    # +---------------------+------------------+---------------+

    # This obviously does not work, but hopefully provides the general idea.
    # Remember: this will need to work with an arbitrary and dynamic set of columns.
    input_cols = ['a', 'b', 'c']
    candidate_columns = [f.col(name) for name in input_cols]
    df = df.withColumn('d', f.coalesce(*candidate_columns))

    # This is the expected output I would like to see for the given inputs
    assert df.collect()[0]['d'] == [2, 1, 7, None]
Thanks in advance for any ideas!
Answers:
Although it would be ideal, I am not sure if there is an elegant way to do this using only pyspark functions.
What I did is write a udf that takes in a variable number of columns (using *args, which you can read about here) and returns an array of integers.
@f.udf(returnType=t.ArrayType(t.IntegerType()))
def get_array_non_null_first_element(*args):
    """Coalesce several equal-length arrays element by element.

    Args:
        *args: One array (a Python list) per input column for the current
            row. A null array (``None``) is skipped, the same way a null
            value is skipped by a regular coalesce.

    Returns:
        A list whose i-th entry is the first non-None value found at
        position i across the inputs, in argument order, or None when every
        input holds None at that position. Returns None when no non-null
        array is supplied.

    Raises:
        ValueError: If the non-null arrays do not all have the same length.
    """
    arrays = [arr for arr in args if arr is not None]
    if not arrays:
        return None

    # Explicit check instead of ``assert``: asserts are stripped under ``-O``.
    if len({len(arr) for arr in arrays}) != 1:
        raise ValueError("all input arrays must have the same length")

    # zip(*arrays) yields, for each position, the candidates in argument order;
    # next(..., None) picks the first non-None one or falls back to None.
    return [
        next((value for value in candidates if value is not None), None)
        for candidates in zip(*arrays)
    ]
Then create a new column d by applying this udf to whichever columns you like:
# Uses the ``f`` alias (as in the ``@f.udf`` decorator); ``F`` is not defined in this snippet.
df.withColumn("d", get_array_non_null_first_element(f.col('a'), f.col('b'), f.col('c'))).show()
+--------------------+------------------+---------------+---------------+
| a| b| c| d|
+--------------------+------------------+---------------+---------------+
|[null, 1, null, n...|[2, 3, null, null]|[5, 6, 7, null]|[2, 1, 7, null]|
+--------------------+------------------+---------------+---------------+
As Derek and the OP have said, Derek’s answer works, but it would be better to avoid UDFs, so here is a way to accomplish it natively:
import pyspark.sql.functions as F
from pyspark.sql.window import Window

# Order by a static value: we just want a row number for every row in the DataFrame
w = Window().orderBy(F.lit('A'))

# Used later to join df with the second DataFrame containing the calculated "d" column
df = df.withColumn("row_num", F.row_number().over(w))
print("DF:")
df.show(truncate=False)

# Input columns
input_cols = ['a', 'b', 'c']

# 1. Zip all the arrays using arrays_zip
# 2. Explode the zipped array
# 3. Create new columns from the exploded struct to get single values
# 4. Coalesce to get the first non-null value
# 5. Group by row_num to bring all the values back into one array
# 6. Wrap each value in an array before collect_list (which ignores nulls),
#    then flatten the nested arrays to get one single flat array
# NOTE(review): collect_list order after a groupBy may not be guaranteed on
# multi-partition data -- verify element order is preserved for real inputs.
df_2 = (
    df.withColumn("new", F.arrays_zip(*input_cols))
    .withColumn("new", F.explode("new"))
    .select("row_num", *[F.col(f"new.{i}").alias(f"new_{i}") for i in input_cols])
    .withColumn("d", F.coalesce(*[F.col(f"new_{i}") for i in input_cols]))
    .groupBy("row_num")
    .agg(F.flatten(F.collect_list(F.array("d"))).alias("d"))
)
print("Second DF:")
df_2.show(truncate=False)

# Join on row_num by name so the key appears only once, then drop it
final_df = df.join(df_2, on="row_num", how="inner").drop("row_num")

# voilà
print("Final DF:")
final_df.show(truncate = False)
assert final_df.collect()[0]["d"] == [2, 1, 7, None]
DF:
+---------------------+------------------+---------------+-------+
|a |b |c |row_num|
+---------------------+------------------+---------------+-------+
|[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|1 |
+---------------------+------------------+---------------+-------+
Second DF:
+-------+---------------+
|row_num|d |
+-------+---------------+
|1 |[2, 1, 7, null]|
+-------+---------------+
Final DF:
+---------------------+------------------+---------------+---------------+
|a |b |c |d |
+---------------------+------------------+---------------+---------------+
|[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|[2, 1, 7, null]|
+---------------------+------------------+---------------+---------------+
Thanks to Derek and Tushar for their responses! I was able to use them as a basis to solve the issue without a UDF, join, or explode.
Generally speaking, joins are unfavorable due to being computationally and memory expensive, UDFs can be computationally intensive, and explode can be memory intensive. Fortunately we can avoid all of these using transform:
from typing import List, Union

import pyspark.sql.functions as f
from pyspark.sql import Column, DataFrame


def add_coalesced_list_by_elements_col(
    df: DataFrame,
    cols: List[Union[Column, str]],
    col_name: str,
) -> DataFrame:
    """Add a column holding the element-wise coalesce of several array columns.

    Position ``i`` of the new array is the first non-null value found at
    position ``i`` across ``cols``, taken in the order given. The positions
    iterated over come from ``cols[0]``. Please note that this does not check
    that all provided columns are of equal length.

    Args:
        df: Input DataFrame to add the column to.
        cols: Non-empty list of array columns to coalesce by element. All
            columns should be of equal length.
        col_name: The name of the new column.

    Returns:
        DataFrame with the result added as a new column.
    """
    return df.withColumn(
        col_name,
        f.transform(
            # Doesn't matter which column: only the index is used, not the value
            cols[0],
            # i + 1 because element_at is 1-based while the transform index is 0-based
            lambda _, i: f.coalesce(*[f.element_at(c, i + 1) for c in cols]),
        ),
    )
def test_collect_list_elements():
    """Check that the helper coalesces arrays per element and leaves the inputs intact."""
    import pyspark.sql.types as t
    from pyspark.sql import SparkSession

    # Arrange
    spark = SparkSession.builder.getOrCreate()
    data = [
        {
            "id": 1,
            "a": [None, 1, None, None],
            "b": [2, 3, None, None],
            "c": [5, 6, 7, None],
        }
    ]
    schema = t.StructType(
        [
            t.StructField("id", t.IntegerType()),
            t.StructField("a", t.ArrayType(t.IntegerType())),
            t.StructField("b", t.ArrayType(t.IntegerType())),
            t.StructField("c", t.ArrayType(t.IntegerType())),
        ]
    )
    df = spark.createDataFrame(data, schema)

    # Act
    df = add_coalesced_list_by_elements_col(df=df, cols=["a", "b", "c"], col_name="d")

    # Collect once and reuse the row: each collect() call launches a new Spark job
    row = df.collect()[0]

    # Assert new col is correct output
    assert row["d"] == [2, 1, 7, None]

    # Assert all the other cols are not affected
    assert row["a"] == [None, 1, None, None]
    assert row["b"] == [2, 3, None, None]
    assert row["c"] == [5, 6, 7, None]
I have an arbitrary number of arrays of equal length in a PySpark DataFrame. I need to coalesce these, element by element, into a single list. The problem with coalesce is that it doesn’t work by element, but rather selects the entire first non-null array. Any suggestions for how to accomplish this would be appreciated. Please see the test case below for an example of expected input and output:
def test_coalesce_elements():
    """
    Test array coalescing on a per-element basis.

    The plain ``coalesce`` call below is only a sketch of the intent: it
    selects a whole array instead of working element by element, so the final
    assertion states the wanted output rather than what this code produces.
    """
    import pyspark.sql.functions as f
    import pyspark.sql.types as t
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    row = {
        "a": [None, 1, None, None],
        "b": [2, 3, None, None],
        "c": [5, 6, 7, None],
    }
    # Every column is a nullable array of nullable integers.
    schema = t.StructType(
        [t.StructField(name, t.ArrayType(t.IntegerType())) for name in ('a', 'b', 'c')]
    )
    df = spark.createDataFrame([row], schema)

    # Inspect schema
    df.printSchema()
    # root
    # | -- a: array(nullable=true)
    # | | -- element: integer(containsNull=true)
    # | -- b: array(nullable=true)
    # | | -- element: integer(containsNull=true)
    # | -- c: array(nullable=true)
    # | | -- element: integer(containsNull=true)

    # Inspect df values
    df.show(truncate=False)
    # +---------------------+------------------+---------------+
    # |a |b |c |
    # +---------------------+------------------+---------------+
    # |[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|
    # +---------------------+------------------+---------------+

    # This obviously does not work, but hopefully provides the general idea.
    # Remember: this will need to work with an arbitrary and dynamic set of columns.
    input_cols = ['a', 'b', 'c']
    candidate_columns = [f.col(name) for name in input_cols]
    df = df.withColumn('d', f.coalesce(*candidate_columns))

    # This is the expected output I would like to see for the given inputs
    assert df.collect()[0]['d'] == [2, 1, 7, None]
Thanks in advance for any ideas!
Although it would be ideal, I am not sure if there is an elegant way to do this using only pyspark functions.
What I did is write a udf that takes in a variable number of columns (using *args, which you can read about here) and returns an array of integers.
@f.udf(returnType=t.ArrayType(t.IntegerType()))
def get_array_non_null_first_element(*args):
    """Coalesce several equal-length arrays element by element.

    Args:
        *args: One array (a Python list) per input column for the current
            row. A null array (``None``) is skipped, the same way a null
            value is skipped by a regular coalesce.

    Returns:
        A list whose i-th entry is the first non-None value found at
        position i across the inputs, in argument order, or None when every
        input holds None at that position. Returns None when no non-null
        array is supplied.

    Raises:
        ValueError: If the non-null arrays do not all have the same length.
    """
    arrays = [arr for arr in args if arr is not None]
    if not arrays:
        return None

    # Explicit check instead of ``assert``: asserts are stripped under ``-O``.
    if len({len(arr) for arr in arrays}) != 1:
        raise ValueError("all input arrays must have the same length")

    # zip(*arrays) yields, for each position, the candidates in argument order;
    # next(..., None) picks the first non-None one or falls back to None.
    return [
        next((value for value in candidates if value is not None), None)
        for candidates in zip(*arrays)
    ]
Then create a new column d by applying this udf to whichever columns you like:
# Uses the ``f`` alias (as in the ``@f.udf`` decorator); ``F`` is not defined in this snippet.
df.withColumn("d", get_array_non_null_first_element(f.col('a'), f.col('b'), f.col('c'))).show()
+--------------------+------------------+---------------+---------------+
| a| b| c| d|
+--------------------+------------------+---------------+---------------+
|[null, 1, null, n...|[2, 3, null, null]|[5, 6, 7, null]|[2, 1, 7, null]|
+--------------------+------------------+---------------+---------------+
As Derek and the OP have said, Derek’s answer works, but it would be better to avoid UDFs, so here is a way to accomplish it natively:
import pyspark.sql.functions as F
from pyspark.sql.window import Window

# Order by a static value: we just want a row number for every row in the DataFrame
w = Window().orderBy(F.lit('A'))

# Used later to join df with the second DataFrame containing the calculated "d" column
df = df.withColumn("row_num", F.row_number().over(w))
print("DF:")
df.show(truncate=False)

# Input columns
input_cols = ['a', 'b', 'c']

# 1. Zip all the arrays using arrays_zip
# 2. Explode the zipped array
# 3. Create new columns from the exploded struct to get single values
# 4. Coalesce to get the first non-null value
# 5. Group by row_num to bring all the values back into one array
# 6. Wrap each value in an array before collect_list (which ignores nulls),
#    then flatten the nested arrays to get one single flat array
# NOTE(review): collect_list order after a groupBy may not be guaranteed on
# multi-partition data -- verify element order is preserved for real inputs.
df_2 = (
    df.withColumn("new", F.arrays_zip(*input_cols))
    .withColumn("new", F.explode("new"))
    .select("row_num", *[F.col(f"new.{i}").alias(f"new_{i}") for i in input_cols])
    .withColumn("d", F.coalesce(*[F.col(f"new_{i}") for i in input_cols]))
    .groupBy("row_num")
    .agg(F.flatten(F.collect_list(F.array("d"))).alias("d"))
)
print("Second DF:")
df_2.show(truncate=False)

# Join on row_num by name so the key appears only once, then drop it
final_df = df.join(df_2, on="row_num", how="inner").drop("row_num")

# voilà
print("Final DF:")
final_df.show(truncate = False)
assert final_df.collect()[0]["d"] == [2, 1, 7, None]
DF:
+---------------------+------------------+---------------+-------+
|a |b |c |row_num|
+---------------------+------------------+---------------+-------+
|[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|1 |
+---------------------+------------------+---------------+-------+
Second DF:
+-------+---------------+
|row_num|d |
+-------+---------------+
|1 |[2, 1, 7, null]|
+-------+---------------+
Final DF:
+---------------------+------------------+---------------+---------------+
|a |b |c |d |
+---------------------+------------------+---------------+---------------+
|[null, 1, null, null]|[2, 3, null, null]|[5, 6, 7, null]|[2, 1, 7, null]|
+---------------------+------------------+---------------+---------------+
Thanks to Derek and Tushar for their responses! I was able to use them as a basis to solve the issue without a UDF, join, or explode.
Generally speaking, joins are unfavorable due to being computationally and memory expensive, UDFs can be computationally intensive, and explode can be memory intensive. Fortunately we can avoid all of these using transform:
from typing import List, Union

import pyspark.sql.functions as f
from pyspark.sql import Column, DataFrame


def add_coalesced_list_by_elements_col(
    df: DataFrame,
    cols: List[Union[Column, str]],
    col_name: str,
) -> DataFrame:
    """Add a column holding the element-wise coalesce of several array columns.

    Position ``i`` of the new array is the first non-null value found at
    position ``i`` across ``cols``, taken in the order given. The positions
    iterated over come from ``cols[0]``. Please note that this does not check
    that all provided columns are of equal length.

    Args:
        df: Input DataFrame to add the column to.
        cols: Non-empty list of array columns to coalesce by element. All
            columns should be of equal length.
        col_name: The name of the new column.

    Returns:
        DataFrame with the result added as a new column.
    """
    return df.withColumn(
        col_name,
        f.transform(
            # Doesn't matter which column: only the index is used, not the value
            cols[0],
            # i + 1 because element_at is 1-based while the transform index is 0-based
            lambda _, i: f.coalesce(*[f.element_at(c, i + 1) for c in cols]),
        ),
    )
def test_collect_list_elements():
    """Check that the helper coalesces arrays per element and leaves the inputs intact."""
    import pyspark.sql.types as t
    from pyspark.sql import SparkSession

    # Arrange
    spark = SparkSession.builder.getOrCreate()
    data = [
        {
            "id": 1,
            "a": [None, 1, None, None],
            "b": [2, 3, None, None],
            "c": [5, 6, 7, None],
        }
    ]
    schema = t.StructType(
        [
            t.StructField("id", t.IntegerType()),
            t.StructField("a", t.ArrayType(t.IntegerType())),
            t.StructField("b", t.ArrayType(t.IntegerType())),
            t.StructField("c", t.ArrayType(t.IntegerType())),
        ]
    )
    df = spark.createDataFrame(data, schema)

    # Act
    df = add_coalesced_list_by_elements_col(df=df, cols=["a", "b", "c"], col_name="d")

    # Collect once and reuse the row: each collect() call launches a new Spark job
    row = df.collect()[0]

    # Assert new col is correct output
    assert row["d"] == [2, 1, 7, None]

    # Assert all the other cols are not affected
    assert row["a"] == [None, 1, None, None]
    assert row["b"] == [2, 3, None, None]
    assert row["c"] == [5, 6, 7, None]