PySpark 3.3.0 – Aggregate sum with condition to avoid self join

Question:

Given the following dataframe structure:

+----------+-----+-------+
|  endPoint|count|outcome|
+----------+-----+-------+
|  getBooks|    3|success|
|  getBooks|    1|failure|
|getClasses|    0|success|
|getClasses|    4|failure|
+----------+-----+-------+

I’m trying to aggregate the data to get a failure rate. My resulting dataframe would look like this.

+----------+-----------+
|  endPoint|failureRate|
+----------+-----------+
|  getBooks|       0.25|
|getClasses|          1|
+----------+-----------+
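Written out, the target metric for an endpoint $e$ is its failure count divided by its total count across all outcomes. The numerator is a conditional sum, which is what the question is really asking how to express. The notation below is introduced only for illustration:

```latex
\text{failureRate}(e) = \frac{\sum_{r \in e,\ \text{outcome}_r = \text{failure}} \text{count}_r}{\sum_{r \in e} \text{count}_r},
\qquad
\text{failureRate}(\text{getBooks}) = \frac{1}{3 + 1} = 0.25,
\qquad
\text{failureRate}(\text{getClasses}) = \frac{4}{0 + 4} = 1
```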

I’m currently able to do this by creating a second dataframe that filters out the success rows, then joining the two dataframes back together and creating a new column that divides the sum of the failed count (for that endpoint) by the sum of the total count.
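One way to picture the filter-and-join approach is the plain-Python model below. It is not Spark code, and the helper name is hypothetical. It makes the two separate aggregations visible: one over the failure rows only and one over all rows, matched up by endpoint. It assumes an inner join, which is the default for Spark's `DataFrame.join`.

```python
from collections import defaultdict

# Sample rows from the question: (endPoint, count, outcome)
rows = [
    ("getBooks", 3, "success"),
    ("getBooks", 1, "failure"),
    ("getClasses", 0, "success"),
    ("getClasses", 4, "failure"),
]


def failure_rate_by_join(rows):
    """Plain-Python model of filter + join (hypothetical helper, not Spark API)."""
    # Aggregation 1: sum counts over failure rows only (the filtered dataframe).
    failed = defaultdict(int)
    for end_point, count, outcome in rows:
        if outcome == "failure":
            failed[end_point] += count

    # Aggregation 2: sum counts over all rows per endpoint.
    total = defaultdict(int)
    for end_point, count, _ in rows:
        total[end_point] += count

    # Inner "join" on endPoint, then divide. Endpoints without failure rows drop out.
    return {ep: failed[ep] / total[ep] for ep in total if ep in failed}


print(failure_rate_by_join(rows))  # {'getBooks': 0.25, 'getClasses': 1.0}
```

In Spark, each of these aggregations is computed on its own before the join, which is the extra work the question wants to avoid.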

I’m trying to find a way to avoid creating a separate dataframe and then having to re-join them back together as it seems expensive and unnecessary. Is there a way to sum columns conditionally? I’ve been playing around with the syntax but am getting stuck.

If I could do something like this:

df.groupBy("endPoint").sum("count").when(outcome = "failure")

that would be ideal but I’m having trouble with this and wonder if I’m missing something fundamental here.

Asked By: daniel9x


Answers:

You can use a when() within the sum aggregate.

import pyspark.sql.functions as func

(
    data_sdf
    .groupBy('end_point')
    .agg(func.sum(func.when(func.col('outcome') == 'failure', func.col('count'))).alias('failure_count'),
         func.sum(func.when(func.col('outcome') == 'success', func.col('count'))).alias('success_count'))
    .withColumn('failure_rate',
                func.col('failure_count') / (func.col('failure_count') + func.col('success_count')))
    .show()
)

# +----------+-------------+-------------+------------+
# | end_point|failure_count|success_count|failure_rate|
# +----------+-------------+-------------+------------+
# |getClasses|            4|            0|         1.0|
# |  getBooks|            1|            3|        0.25|
# +----------+-------------+-------------+------------+
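One detail this answer relies on: `F.when()` without `.otherwise()` returns null for rows that don't match, and `sum()` skips nulls but returns null for a group that has no non-null values. In the sample data every endpoint has both outcomes, but an endpoint with only failure rows would get `success_count = null` and therefore `failure_rate = null`. The plain-Python model below (not Spark code; the helper names and the `getAuthors` endpoint are hypothetical) shows the effect. In PySpark the usual remedy is to add `.otherwise(0)` to each `when()`.

```python
# Sample rows from the question plus one hypothetical endpoint ("getAuthors")
# that has only failure rows: (end_point, count, outcome)
rows = [
    ("getBooks", 3, "success"),
    ("getBooks", 1, "failure"),
    ("getClasses", 0, "success"),
    ("getClasses", 4, "failure"),
    ("getAuthors", 2, "failure"),
]


def spark_sum(values):
    """Model of SQL SUM: skips None (null) and returns None if no input is non-null."""
    non_null = [v for v in values if v is not None]
    return sum(non_null) if non_null else None


def when(cond, value, otherwise=None):
    """Model of F.when(cond, value)[.otherwise(x)]: None when cond is false and no otherwise."""
    return value if cond else otherwise


def failure_rate(rows, end_point, otherwise=None):
    """Model of the sum(when(...)) answer for one group (hypothetical helper)."""
    group = [(count, outcome) for ep, count, outcome in rows if ep == end_point]
    failure = spark_sum([when(o == "failure", c, otherwise) for c, o in group])
    success = spark_sum([when(o == "success", c, otherwise) for c, o in group])
    if failure is None or success is None:
        return None  # arithmetic with a null operand yields null in Spark
    total = failure + success
    return failure / total if total else None  # x / 0 is null with ANSI mode off


print(failure_rate(rows, "getBooks"))                 # 0.25
print(failure_rate(rows, "getAuthors"))               # None
print(failure_rate(rows, "getAuthors", otherwise=0))  # 1.0
```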
Answered By: samkart

You can pivot your dataframe to get a wide version where the outcome strings become separate columns, each containing the sum of the count column. From that dataframe, you can calculate your failure rate:

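import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()  # or reuse the shell's existing `spark`
# init example table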
df = spark.createDataFrame(
    [
        ("getBooks", 3, "success"),
        ("getBooks", 1, "failure"),
        ("getClasses", 0, "success"),
        ("getClasses", 4, "failure"),
    ],
    ["endPoint", "count", "outcome"],
)
df.show()
df_pivot = df.groupBy("endPoint").pivot("outcome", ["success", "failure"]).sum("count")
df_pivot.show()
df_total = df_pivot.withColumn("total", F.col("success") + F.col("failure"))
df_total.show()
df_failure_rate = df_total.select("endPoint", (F.col("failure") / F.col("total")).alias("failureRate"))
df_failure_rate.show()

Output:

+----------+-----+-------+
|  endPoint|count|outcome|
+----------+-----+-------+
|  getBooks|    3|success|
|  getBooks|    1|failure|
|getClasses|    0|success|
|getClasses|    4|failure|
+----------+-----+-------+

+----------+-------+-------+
|  endPoint|success|failure|
+----------+-------+-------+
|getClasses|      0|      4|
|  getBooks|      3|      1|
+----------+-------+-------+

+----------+-------+-------+-----+
|  endPoint|success|failure|total|
+----------+-------+-------+-----+
|getClasses|      0|      4|    4|
|  getBooks|      3|      1|    4|
+----------+-------+-------+-----+

+----------+-----------+
|  endPoint|failureRate|
+----------+-----------+
|getClasses|        1.0|
|  getBooks|       0.25|
+----------+-----------+
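The pivot step can be pictured with the plain-Python model below (not Spark code; the helper names are hypothetical): each outcome value becomes its own column holding the summed count, and a combination with no rows is left as null. Passing the value list `["success", "failure"]` explicitly, as the answer does, also spares Spark from first computing the distinct outcome values, which the `pivot()` documentation describes as the more efficient form.

```python
from collections import defaultdict

# Sample rows from the question: (endPoint, count, outcome)
rows = [
    ("getBooks", 3, "success"),
    ("getBooks", 1, "failure"),
    ("getClasses", 0, "success"),
    ("getClasses", 4, "failure"),
]


def pivot_sum(rows, values):
    """Model of groupBy(endPoint).pivot(outcome, values).sum(count) (hypothetical helper)."""
    # One output row per endPoint, one column per pivot value; empty cells stay None (null).
    wide = defaultdict(lambda: {v: None for v in values})
    for end_point, count, outcome in rows:
        if outcome in values:
            cell = wide[end_point][outcome]
            wide[end_point][outcome] = count if cell is None else cell + count
    return dict(wide)


def failure_rates(rows):
    """Model of the follow-up total and failureRate columns."""
    rates = {}
    for end_point, cells in pivot_sum(rows, ["success", "failure"]).items():
        success, failure = cells["success"], cells["failure"]
        if success is None or failure is None or success + failure == 0:
            rates[end_point] = None  # null operand or division by zero -> null
        else:
            rates[end_point] = failure / (success + failure)
    return rates


print(failure_rates(rows))  # {'getBooks': 0.25, 'getClasses': 1.0}
```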
Answered By: fskj

This is easily achieved by using Spark windows:

import pyspark.sql.functions as F
from pyspark.sql import Window

w = Window.partitionBy("endPoint")

(
    df.withColumn("total", F.sum("count").over(w))
    .withColumn("failureRate", F.col("count") / F.col("total"))
    .where(F.col("outcome") == "failure")
    .select("endPoint", "failureRate")
    .show()
)
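A window sum keeps every row and attaches the partition total to each one, instead of collapsing rows the way `groupBy` does. The plain-Python model below (not Spark code; the helper name is hypothetical) walks through the same steps on the sample data.

```python
from collections import defaultdict

# Sample rows from the question: (endPoint, count, outcome)
rows = [
    ("getBooks", 3, "success"),
    ("getBooks", 1, "failure"),
    ("getClasses", 0, "success"),
    ("getClasses", 4, "failure"),
]


def failure_rate_window(rows):
    """Model of sum(count).over(Window.partitionBy(endPoint)) + filter (hypothetical helper)."""
    # Window step: compute each partition's total...
    totals = defaultdict(int)
    for end_point, count, _ in rows:
        totals[end_point] += count
    # ...and attach it to every row; unlike groupBy, no rows are collapsed.
    with_total = [(ep, c, o, totals[ep]) for ep, c, o in rows]
    # Keep failure rows and divide each row's own count by its partition total.
    return [
        (ep, c / t if t else None)  # x / 0 is null in Spark with ANSI mode off
        for ep, c, o, t in with_total
        if o == "failure"
    ]


print(failure_rate_window(rows))  # [('getBooks', 0.25), ('getClasses', 1.0)]
```

Because each failure row's own count is divided by the total, this yields one output row per failure row. It matches the `groupBy` answers only when every endpoint has at most one failure row, as in the sample data.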
Answered By: bzu

This has been run with the sample data from fskj.

But generally, this gives you the idea of what I want to do. It might be worth comparing the explain() output of this plan with the other ones to see which is more efficient. (Really, it's the only way to determine which is better.)

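import pyspark.sql.functions as F

result = (
    df
    .groupBy('endPoint', 'outcome')
    # 'outcome' is a string column, so Spark casts it to double for sum()
    # (see the cast in the plan below); count("endPoint") counts rows per group.
    .agg(F.sum('outcome').alias("sum"), F.count("endPoint").alias("count"))
    .where(F.col('outcome') != "success")
    .withColumn('failure_rate', F.col('sum') / F.col('count'))
    .select('endPoint', 'failure_rate')
)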

Explain for my solution (I believe this to be the most efficient, as it uses predicate pushdown to remove "success" rows early and therefore operates on less data. It also does not require a sort):

== Physical Plan ==
*(2) HashAggregate(keys=[endPoint#78, outcome#80], functions=[sum(cast(outcome#80 as double)), count(endPoint#78)])
+- Exchange hashpartitioning(endPoint#78, outcome#80, 200)
   +- *(1) HashAggregate(keys=[endPoint#78, outcome#80], functions=[partial_sum(cast(outcome#80 as double)), partial_count(endPoint#78)])
      +- *(1) Project [endPoint#78, outcome#80]
         +- *(1) Filter (isnotnull(outcome#80) && NOT (outcome#80 = success))
            +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]

Window explain:

== Physical Plan ==
*(2) Project [endPoint#78, (cast(count#79L as double) / cast(total#1016L as double)) AS failureRate#1021]
+- *(2) Filter (isnotnull(outcome#80) && (outcome#80 = failure))
   +- Window [sum(count#79L) windowspecdefinition(endPoint#78, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS total#1016L], [endPoint#78]
      +- *(1) Sort [endPoint#78 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(endPoint#78, 200)
            +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]

Pivot explain: (not as efficient)

df_failure_rate.explain()
== Physical Plan ==
HashAggregate(keys=[endPoint#78], functions=[pivotfirst(outcome#80, sum(`count`)#89L, success, failure, 0, 0)])
+- Exchange hashpartitioning(endPoint#78, 200)
   +- HashAggregate(keys=[endPoint#78], functions=[partial_pivotfirst(outcome#80, sum(`count`)#89L, success, failure, 0, 0)])
      +- *(2) HashAggregate(keys=[endPoint#78, outcome#80], functions=[sum(count#79L)])
         +- Exchange hashpartitioning(endPoint#78, outcome#80, 200)
            +- *(1) HashAggregate(keys=[endPoint#78, outcome#80], functions=[partial_sum(count#79L)])
               +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]

when() answer explain (very efficient if the data set is small; it doesn’t filter out ‘success’ rows, so it processes all of them):

== Physical Plan ==
*(2) HashAggregate(keys=[end_point#1039], functions=[sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
+- Exchange hashpartitioning(end_point#1039, 200)
   +- *(1) HashAggregate(keys=[end_point#1039], functions=[partial_sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), partial_sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
      +- *(1) Project [endPoint#78 AS end_point#1039, count#79L, outcome#80]
         +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]
Answered By: Matt Andruff

Here’s the most efficient version (well, it ties with the other solution I provided), which builds on the answer from @samkart.

So now it just comes down to which one you find easier to understand.

import pyspark.sql.functions as func

(
    df
    .filter(func.col("outcome") == "failure")
    .groupBy('end_point')
    .agg(func.sum(func.when(func.col('outcome') == 'failure', func.col('count'))).alias('failure_count'),
         func.sum(func.when(func.col('outcome') == 'success', func.col('count'))).alias('success_count'))
    # after the filter above no 'success' rows remain, so success_count is null here
    .withColumn('failure_rate',
                func.col('failure_count') / (func.col('failure_count') + func.col('success_count')))
    .explain()
)

Explain

== Physical Plan ==
*(2) HashAggregate(keys=[end_point#1039], functions=[sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
+- Exchange hashpartitioning(end_point#1039, 200)
   +- *(1) HashAggregate(keys=[end_point#1039], functions=[partial_sum(CASE WHEN (outcome#80 = failure) THEN count#79L END), partial_sum(CASE WHEN (outcome#80 = success) THEN count#79L END)])
      +- *(1) Project [endPoint#78 AS end_point#1039, count#79L, outcome#80]
         +- *(1) Filter (isnotnull(outcome#80) && (outcome#80 = failure))
            +- Scan ExistingRDD[endPoint#78,count#79L,outcome#80]
Answered By: Matt Andruff