Pair combinations of array column in PySpark

Question:

Similar to this question (Scala), but I need pair combinations of an array column in PySpark.

Example input:

df = spark.createDataFrame(
    [([0, 1],),
     ([2, 3, 4],),
     ([5, 6, 7, 8],)],
    ['array_col'])

Expected output:

+------------+------------------------------------------------+
|array_col   |out                                             |
+------------+------------------------------------------------+
|[0, 1]      |[[0, 1]]                                        |
|[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
|[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
+------------+------------------------------------------------+
Asked By: ZygD


Answers:

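A pandas_udf is a concise and efficient approach in PySpark. It is vectorized with Apache Arrow, so rows are processed in batches as pandas Series rather than one at a time.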

from pyspark.sql import functions as F
import pandas as pd
from itertools import combinations

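@F.pandas_udf('array<array<int>>')
def pudf(c: pd.Series) -> pd.Series:
    # each value of c is one row's array; combinations(x, 2) yields
    # every 2-element pair in positional order
    return c.apply(lambda x: list(combinations(x, 2)))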


df = df.withColumn('out', pudf('array_col'))
df.show(truncate=0)
# +------------+------------------------------------------------+
# |array_col   |out                                             |
# +------------+------------------------------------------------+
# |[0, 1]      |[[0, 1]]                                        |
# |[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
# |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
# +------------+------------------------------------------------+

Note: in some environments, instead of the DDL string 'array<array<int>>' you may need to pass a type object from pyspark.sql.types, e.g. ArrayType(ArrayType(IntegerType())).
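
The per-row work the UDF does can be checked in plain Python, without Spark. The sketch below uses a hypothetical helper `pairs` (not part of the answer) that mirrors the lambda: itertools.combinations(x, 2) yields tuples, and converting them to lists gives the array<array<int>> shape. Because combinations works on positions, it keeps the original order and also pairs duplicate values.

```python
from itertools import combinations

def pairs(arr):
    """Hypothetical helper: all 2-element combinations of arr, by position."""
    # combinations() yields tuples in positional order; convert each to a list
    return [list(p) for p in combinations(arr, 2)]

for row in ([0, 1], [2, 3, 4], [5, 6, 7, 8]):
    print(row, pairs(row))
# [0, 1] [[0, 1]]
# [2, 3, 4] [[2, 3], [2, 4], [3, 4]]
# [5, 6, 7, 8] [[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]

print(pairs([3, 1, 2]))  # [[3, 1], [3, 2], [1, 2]]
print(pairs([7]))        # []
```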

Answered By: ZygD

A native Spark approach (no Python UDF). I’ve translated this answer to PySpark.

Python 3.8+ (the walrus operator := binds "array_col", which is repeated several times in this script) and Spark 3.1+ (F.transform and F.filter with Python lambdas):

from pyspark.sql import functions as F

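df = df.withColumn(
    "out",
    F.filter(
        F.transform(
            # for each element x, zip size(c) copies of x with the whole array,
            # then flatten the per-element arrays into one array of structs
            F.flatten(F.transform(
                c := "array_col",
                lambda x: F.arrays_zip(F.array_repeat(x, F.size(c)), c)
            )),
            # struct fields: "0" (the repeated x) and "array_col" (the array)
            lambda x: F.array(x["0"], x[c])
        ),
        # keep only pairs whose first value is smaller than the second
        lambda x: x[0] < x[1]
    )
)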
df.show(truncate=0)
# +------------+------------------------------------------------+
# |array_col   |out                                             |
# +------------+------------------------------------------------+
# |[0, 1]      |[[0, 1]]                                        |
# |[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
# |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
# +------------+------------------------------------------------+
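
To see what each step produces, and what the filter assumes, here is a pure-Python model of the same pipeline (the helpers native_pairs and positional_pairs are hypothetical; Spark evaluates the real expression natively). Each element is zipped with the whole array, giving n² pairs, and the filter keeps those whose first value is smaller. Since the filter compares values rather than positions, the result matches itertools.combinations only when the array's elements are distinct and in ascending order, as in the example data.

```python
from itertools import combinations

def native_pairs(arr):
    """Hypothetical pure-Python model of the native Spark expression."""
    n = len(arr)
    # transform + arrays_zip(array_repeat(x, size), arr):
    # pair n copies of each element x with the whole array
    zipped = [list(zip([x] * n, arr)) for x in arr]
    # flatten: one list holding all n * n (x, y) pairs
    flat = [pair for group in zipped for pair in group]
    # transform: turn each (x, y) struct into a 2-element array
    as_arrays = [[x, y] for x, y in flat]
    # filter: keep pairs whose first value is smaller than the second
    return [p for p in as_arrays if p[0] < p[1]]

def positional_pairs(arr):
    """Hypothetical helper: combinations by position, as the pandas_udf computes."""
    return [list(p) for p in combinations(arr, 2)]

# Sorted, distinct values (the example data): both agree
for row in ([0, 1], [2, 3, 4], [5, 6, 7, 8]):
    assert native_pairs(row) == positional_pairs(row)

# Unsorted or duplicate values: the results differ
print(native_pairs([3, 1, 2]))      # [[1, 3], [1, 2], [2, 3]]
print(positional_pairs([3, 1, 2]))  # [[3, 1], [3, 2], [1, 2]]
print(native_pairs([1, 1, 2]))      # [[1, 2], [1, 2]]
print(positional_pairs([1, 1, 2]))  # [[1, 1], [1, 2], [1, 2]]
```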

Alternative without walrus operator:

from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.filter(
        F.transform(
            F.flatten(F.transform(
                "array_col",
                lambda x: F.arrays_zip(F.array_repeat(x, F.size("array_col")), "array_col")
            )),
            lambda x: F.array(x["0"], x["array_col"])
        ),
        lambda x: x[0] < x[1]
    )
)

Alternative for Spark 2.4+ (SQL higher-order functions via F.expr):

from pyspark.sql import functions as F

df = df.withColumn(
    "out",
    F.expr("""
        filter(
            transform(
                flatten(transform(
                    array_col,
                    x -> arrays_zip(array_repeat(x, size(array_col)), array_col)
                )),
                x -> array(x["0"], x["array_col"])
            ),
            x -> x[0] < x[1]
        )
    """)
)
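
All of the native variants build every (x, y) pair in a row before the filter runs, so for an n-element array the intermediate array has n² entries while, for distinct values, only n(n-1)/2 survive. The small arithmetic sketch below, using a hypothetical pair_counts helper, makes that ratio concrete; it says nothing about actual Spark runtime, which also depends on data volume and cluster setup.

```python
def pair_counts(n):
    """Hypothetical helper: (intermediate, kept) pair counts for n distinct values."""
    intermediate = n * n      # each element zipped with every element, before filter
    kept = n * (n - 1) // 2   # pairs surviving the x[0] < x[1] filter
    return intermediate, kept

for n in (2, 3, 4, 100):
    print(n, pair_counts(n))
# 2 (4, 1)
# 3 (9, 3)
# 4 (16, 6)
# 100 (10000, 4950)
```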

Answered By: ZygD