PySpark – Cumulative sum with limits

Question:

I have a dataframe as follows:

+-------+----------+-----+
|user_id|      date|valor|
+-------+----------+-----+
|      1|2022-01-01|    0|
|      1|2022-01-02|    0|
|      1|2022-01-03|    1|
|      1|2022-01-04|    1|
|      1|2022-01-05|    1|
|      1|2022-01-06|    0|
|      1|2022-01-07|    0|
|      1|2022-01-08|    0|
|      1|2022-01-09|    1|
|      1|2022-01-10|    1|
|      1|2022-01-11|    1|
|      1|2022-01-12|    0|
|      1|2022-01-13|    0|
|      1|2022-01-14|   -1|
|      1|2022-01-15|   -1|
|      1|2022-01-16|   -1|
|      1|2022-01-17|   -1|
|      1|2022-01-18|   -1|
|      1|2022-01-19|   -1|
|      1|2022-01-20|    0|
+-------+----------+-----+

The goal is to calculate a score for each user_id, using valor as the base. The score starts at 3 and increases or decreases by 1 as it moves through the valor column. The main constraint is that the score can never go below 1 or above 5, so the running sum must stay within that range while still carrying the last value forward so the next row is computed correctly. This is what I expect:

+-------+----------+-----+-----+
|user_id|      date|valor|score|
+-------+----------+-----+-----+
|      1|2022-01-01|    0|    3|
|      1|2022-01-02|    0|    3|
|      1|2022-01-03|    1|    4|
|      1|2022-01-04|    1|    5|
|      1|2022-01-05|    1|    5|
|      1|2022-01-06|    0|    5|
|      1|2022-01-07|    0|    5|
|      1|2022-01-08|    0|    5|
|      1|2022-01-09|    1|    5|
|      1|2022-01-10|   -1|    4|
|      1|2022-01-11|   -1|    3|
|      1|2022-01-12|    0|    3|
|      1|2022-01-13|    0|    3|
|      1|2022-01-14|   -1|    2|
|      1|2022-01-15|   -1|    1|
|      1|2022-01-16|    1|    2|
|      1|2022-01-17|   -1|    1|
|      1|2022-01-18|   -1|    1|
|      1|2022-01-19|    1|    2|
|      1|2022-01-20|    0|    2|
+-------+----------+-----+-----+
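In other words, the rule is score = min(5, max(1, previous_score + valor)), starting from 3. A plain-Python sketch of that rule, just to make the expected behaviour concrete (not a Spark solution):

# plain-Python sketch of the desired rule: start at 3, add valor row by row,
# and clamp the running score to the range [1, 5]
def clamped_scores(valores, start=3, low=1, high=5):
    score = start
    out = []
    for v in valores:
        score = min(high, max(low, score + v))
        out.append(score)
    return out

# e.g. clamped_scores([0, 0, 1, 1, 1]) -> [3, 3, 4, 5, 5]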

So far, I’ve built a window to rank the valor column, so I can track how many increases or decreases occur in a row and zero out the contributions of a sequence beyond its fourth element, but I don’t know how to keep the cumulative sum in valor_ within the range [1, 5]:

+-------+----------+----+-----+------+
|user_id|      date|rank|valor|valor_|
+-------+----------+----+-----+------+
|      1|2022-01-01|   0|    0|     0|
|      1|2022-01-02|   0|    0|     0|
|      1|2022-01-03|   1|    1|     1|
|      1|2022-01-04|   2|    1|     1|
|      1|2022-01-05|   3|    1|     1|
|      1|2022-01-06|   0|    0|     0|
|      1|2022-01-07|   0|    0|     0|
|      1|2022-01-08|   0|    0|     0|
|      1|2022-01-09|   1|    1|     1|
|      1|2022-01-10|   2|    1|     1|
|      1|2022-01-11|   3|    1|     1|
|      1|2022-01-12|   0|    0|     0|
|      1|2022-01-13|   0|    0|     0|
|      1|2022-01-14|   1|   -1|    -1|
|      1|2022-01-15|   2|   -1|    -1|
|      1|2022-01-16|   3|   -1|    -1|
|      1|2022-01-17|   4|   -1|    -1|
|      1|2022-01-18|   5|   -1|     0|
|      1|2022-01-19|   6|   -1|     0|
+-------+----------+----+-----+------+

As you can see, the result here is not what I expected:

+-------+----------+----+-----+------+-----+
|user_id|      date|rank|valor|valor_|score|
+-------+----------+----+-----+------+-----+
|      1|2022-01-01|   0|    0|     0|    3|
|      1|2022-01-02|   0|    0|     0|    3|
|      1|2022-01-03|   1|    1|     1|    4|
|      1|2022-01-04|   2|    1|     1|    5|
|      1|2022-01-05|   3|    1|     1|    6|
|      1|2022-01-06|   0|    0|     0|    6|
|      1|2022-01-07|   0|    0|     0|    6|
|      1|2022-01-08|   0|    0|     0|    6|
|      1|2022-01-09|   1|    1|     1|    7|
|      1|2022-01-10|   2|    1|     1|    8|
|      1|2022-01-11|   3|    1|     1|    9|
|      1|2022-01-12|   0|    0|     0|    9|
|      1|2022-01-13|   0|    0|     0|    9|
|      1|2022-01-14|   1|   -1|    -1|    8|
|      1|2022-01-15|   2|   -1|    -1|    7|
|      1|2022-01-16|   3|   -1|    -1|    6|
|      1|2022-01-17|   4|   -1|    -1|    5|
|      1|2022-01-18|   5|   -1|     0|    5|
|      1|2022-01-19|   6|   -1|     0|    5|
|      1|2022-01-20|   0|    0|     0|    5|
+-------+----------+----+-----+------+-----+
Asked By: Hiago Reis


Answers:

tl;dr: a fairly complex, RDD-based approach; consider it a last resort due to its complexity.

A Python function can keep track of the previous cumulative-sum value, and that function can be applied with flatMapValues() to process each user’s rows.

Consider the following input data

import random

import pandas as pd
from pyspark.sql import functions as func

# assumes an active SparkSession available as `spark`
data1_ls = [(1, k.strftime('%Y-%m-%d'), random.randint(-1, 1)) for k in pd.date_range(pd.to_datetime('2022-01-01'), pd.to_datetime('2022-01-20'))]
data2_ls = [(2, k.strftime('%Y-%m-%d'), random.randint(-1, 1)) for k in pd.date_range(pd.to_datetime('2022-04-01'), pd.to_datetime('2022-04-30'))]

data1_sdf = spark.sparkContext.parallelize(data1_ls).toDF(['user', 'dt', 'valor']). \
    withColumn('dt', func.col('dt').cast('date'))

data2_sdf = spark.sparkContext.parallelize(data2_ls).toDF(['user', 'dt', 'valor']). \
    withColumn('dt', func.col('dt').cast('date'))

data_sdf = data1_sdf.unionByName(data2_sdf)

# +----+----------+-----+
# |user|        dt|valor|
# +----+----------+-----+
# |   1|2022-01-01|    1|
# |   1|2022-01-02|   -1|
# |   1|2022-01-03|    0|
# |   1|2022-01-04|    1|
# |   1|2022-01-05|    0|
# +----+----------+-----+

We can write a Python function that computes the running sum and carries it from one row to the next, clamping it to the [1, 5] range. Spark ships this function to the executors, so each user’s group of rows is processed in a distributed fashion.

def cumsum_in_range(groupedRows, initial_value=3):
    """
    Cumulative sum of `valor` over the rows of one group, clamped to the
    range [1, 5] and starting from `initial_value`.
    """

    res = []
    frstRec = True

    for row in groupedRows:
        if frstRec:
            # the running sum starts from a static initial value
            frstRec = False
            cumsum = initial_value + row.valor
        else:
            cumsum = prev_cumsum + row.valor

        # clamp the running sum to the allowed range
        if cumsum > 5:
            cumsum = 5
        elif cumsum < 1:
            cumsum = 1

        prev_cumsum = cumsum  # keep track of the latest sum for the next iteration

        res.append([item for item in row] + [cumsum])

    return res

To apply the function we’ll use groupBy() and flatMapValues() on the underlying RDD. groupBy() groups the records by the key returned by its lambda (the user). The records also need to be ordered by the date field for the cumulative sum, so sorted() is called with the date field as the key.

# run the python function and keep only the resulting values
res_vals = data_sdf.rdd. \
    groupBy(lambda gk: gk.user). \
    flatMapValues(lambda r: cumsum_in_range(sorted(r, key=lambda ok: ok.dt))). \
    values()

# build the result schema: derive a fresh copy of the dataframe's schema
# (via withColumn + drop) and append the new cumsum column
data_schema = data_sdf.withColumn('dropme', func.lit(None).cast('int')). \
    drop('dropme'). \
    schema. \
    add('cumsum', 'integer')

# create a dataframe with the new values
res_sdf = spark.createDataFrame(res_vals, data_schema)

The res_sdf dataframe now has the clamped cumulative sum column for each user, produced by the Python function defined above.

res_sdf. \
    filter(func.col('user') == 1). \
    orderBy(['user', 'dt']). \
    show()

# +----+----------+-----+------+
# |user|        dt|valor|cumsum|
# +----+----------+-----+------+
# |   1|2022-01-01|    1|     4|
# |   1|2022-01-02|   -1|     3|
# |   1|2022-01-03|    0|     3|
# |   1|2022-01-04|    1|     4|
# |   1|2022-01-05|    0|     4|
# |   1|2022-01-06|    1|     5|
# |   1|2022-01-07|    0|     5|
# |   1|2022-01-08|    1|     5|
# |   1|2022-01-09|    0|     5|
# |   1|2022-01-10|   -1|     4|
# |   1|2022-01-11|   -1|     3|
# |   1|2022-01-12|   -1|     2|
# |   1|2022-01-13|    1|     3|
# |   1|2022-01-14|   -1|     2|
# |   1|2022-01-15|    1|     3|
# |   1|2022-01-16|   -1|     2|
# |   1|2022-01-17|    0|     2|
# |   1|2022-01-18|    1|     3|
# |   1|2022-01-19|    0|     3|
# |   1|2022-01-20|   -1|     2|
# +----+----------+-----+------+

res_sdf. \
    filter(func.col('user') == 2). \
    orderBy(['user', 'dt']). \
    show()

# +----+----------+-----+------+
# |user|        dt|valor|cumsum|
# +----+----------+-----+------+
# |   2|2022-04-01|   -1|     2|
# |   2|2022-04-02|    0|     2|
# |   2|2022-04-03|    1|     3|
# |   2|2022-04-04|   -1|     2|
# |   2|2022-04-05|    1|     3|
# |   2|2022-04-06|    0|     3|
# |   2|2022-04-07|    1|     4|
# |   2|2022-04-08|   -1|     3|
# |   2|2022-04-09|    0|     3|
# |   2|2022-04-10|    0|     3|
# |   2|2022-04-11|   -1|     2|
# |   2|2022-04-12|    1|     3|
# |   2|2022-04-13|    0|     3|
# |   2|2022-04-14|    0|     3|
# |   2|2022-04-15|    1|     4|
# |   2|2022-04-16|   -1|     3|
# |   2|2022-04-17|    0|     3|
# |   2|2022-04-18|    0|     3|
# |   2|2022-04-19|    1|     4|
# |   2|2022-04-20|    1|     5|
# +----+----------+-----+------+
# only showing top 20 rows
Answered By: samkart

In cases like this we usually think of window functions to carry a calculation from one row to the next. But this case is different: each row’s result depends on the already clamped result of the previous row, so the window would have to keep track of itself, and window functions cannot do that.

Main idea: instead of operating on rows, do the work on grouped/aggregated arrays. That works very well here, because we have a key to use in groupBy, so the table is split into chunks of data and the calculations are parallelized.

Input:

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(1, '2022-01-01',  0),
     (1, '2022-01-02',  0),
     (1, '2022-01-03',  1),
     (1, '2022-01-04',  1),
     (1, '2022-01-05',  1),
     (1, '2022-01-06',  0),
     (1, '2022-01-07',  0),
     (1, '2022-01-08',  0),
     (1, '2022-01-09',  1),
     (1, '2022-01-10',  1),
     (1, '2022-01-11',  1),
     (1, '2022-01-12',  0),
     (1, '2022-01-13',  0),
     (1, '2022-01-14', -1),
     (1, '2022-01-15', -1),
     (1, '2022-01-16', -1),
     (1, '2022-01-17', -1),
     (1, '2022-01-18', -1),
     (1, '2022-01-19', -1),
     (1, '2022-01-20',  0)],
    ['user_id', 'date', 'valor'])

Script:

df = df.groupBy('user_id').agg(
    F.aggregate(
        F.array_sort(F.collect_list(F.struct('date', 'valor'))),
        F.expr("array(struct(cast(null as string) date, 0L valor, 3L cum))"),
        lambda acc, x: F.array_union(
            acc,
            F.array(x.withField(
                'cum',
                F.greatest(F.lit(1), F.least(F.lit(5), x['valor'] + F.element_at(acc, -1)['cum']))
            ))
        )
    ).alias("a")
)
df = df.selectExpr("user_id", "inline(slice(a, 2, size(a)))")

df.show()
# +-------+----------+-----+---+
# |user_id|      date|valor|cum|
# +-------+----------+-----+---+
# |      1|2022-01-01|    0|  3|
# |      1|2022-01-02|    0|  3|
# |      1|2022-01-03|    1|  4|
# |      1|2022-01-04|    1|  5|
# |      1|2022-01-05|    1|  5|
# |      1|2022-01-06|    0|  5|
# |      1|2022-01-07|    0|  5|
# |      1|2022-01-08|    0|  5|
# |      1|2022-01-09|    1|  5|
# |      1|2022-01-10|    1|  5|
# |      1|2022-01-11|    1|  5|
# |      1|2022-01-12|    0|  5|
# |      1|2022-01-13|    0|  5|
# |      1|2022-01-14|   -1|  4|
# |      1|2022-01-15|   -1|  3|
# |      1|2022-01-16|   -1|  2|
# |      1|2022-01-17|   -1|  1|
# |      1|2022-01-18|   -1|  1|
# |      1|2022-01-19|   -1|  1|
# |      1|2022-01-20|    0|  1|
# +-------+----------+-----+---+

Explanation

Groups are created based on "user_id". The aggregation for these groups lies in this line:

F.array_sort(F.collect_list(F.struct('date', 'valor')))

This creates an array (collect_list) for every "user_id". The arrays contain structs with 2 fields: date and valor.

+-------+-----------------------------------------------+
|user_id|a                                              |
+-------+-----------------------------------------------+
|1      |[{2022-01-01, 0}, {2022-01-02, 0}, {...} ... ] |
+-------+-----------------------------------------------+

array_sort is used to make sure the structs inside every array are sorted by date, because the following steps depend on that order.
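To inspect this intermediate result yourself, you can run just that part of the aggregation, assuming df still refers to the input dataframe from the Input section (the Script above reassigns it). array_sort compares structs field by field, so putting date first in the struct is what makes the arrays chronological:

# quick check of the intermediate stage: one date-sorted array of
# {date, valor} structs per user_id (df is the input dataframe here)
df.groupBy('user_id').agg(
    F.array_sort(F.collect_list(F.struct('date', 'valor'))).alias('a')
).show(truncate=False)
# one row for user_id 1, holding all 20 {date, valor} structs in date order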

Everything else inside agg transforms the result of that aggregation.

The main part of the code is aggregate. It takes an array, "loops" through every element and returns one value (in our case, this value is made to be an array too). It works like this: you take the initial value (array(struct(cast(null as string) date, 0L valor, 3L cum))) and merge it with the first element of the array using the provided function (the lambda). The result then takes the place of the initial value for the next step, where you merge it with the following element of the array, and so on.
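As a smaller, standalone illustration of that fold behaviour (toy integers, not the answer’s structs), aggregate with 0 as the initial value and plain addition as the merge function behaves like a running reduce:

# standalone sketch of how F.aggregate folds left to right:
# acc starts at the initial value, the lambda merges acc with each element
spark.range(1).select(
    F.aggregate(
        F.array(F.lit(1), F.lit(2), F.lit(3)),   # the array to fold over
        F.lit(0),                                # initial value of acc
        lambda acc, x: acc + x                   # merge function
    ).alias('total')
).show()
# +-----+
# |total|
# +-----+
# |    6|
# +-----+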

In this case, the lambda function performs array_union, which makes a union of arrays having identical schemas.
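As a side note (a standalone sketch with toy integers, not the answer’s data): array_union concatenates two arrays of the same element type and removes duplicates, which is safe here because every struct carries a distinct date:

# standalone sketch: array_union keeps the distinct elements of both arrays
spark.range(1).select(
    F.array_union(F.array(F.lit(1), F.lit(2)), F.array(F.lit(2), F.lit(3))).alias('u')
).show()
# +---------+
# |        u|
# +---------+
# |[1, 2, 3]|
# +---------+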

  1. We take the initial value (array of structs) as acc variable
    [{null, 0, 3}]
    (it’s already ready to be used in array_union)

  2. take the first element of the ‘a’ column’s array as the x variable
    {2022-01-01, 0}
    (it’s a struct, so its schema is not the same as acc’s (an array of structs); some processing is needed, and the calculation also has to happen at this step, because this is the point where we have access to both variables)

  3. we’ll create the array of structs by enclosing the x struct inside F.array(); also, we’ll have to add one more field to the struct, as x struct currently has just 2 fields
    F.array(x.withField('cum', ...))

  4. inside the .withField() we have to provide the expression for the field

    F.greatest(
        F.lit(1),
        F.least(
            F.lit(5),
            x['valor'] + F.element_at(acc, -1)['cum']
        )
    )
    

    element_at(acc, -1) takes the last struct of the acc array
    ['cum'] takes the field ‘cum’ from that struct
    x['valor'] + adds the ‘valor’ field from the x struct
    F.least() ensures that the max value of ‘cum’ stays at 5 (it takes the smaller of the new ‘cum’ and 5)
    F.greatest() ensures that the min value of ‘cum’ stays at 1

  5. both acc and the newly created array of structs now have identical schemas and proper data, so they can be combined:
    array_union
    the result is assigned to the acc variable, while x gets the next element of the ‘a’ array.
    The process continues from step 3.

Finally, the result of aggregate looks like this:
[{null, 0, 3}, {2022-01-01, 0, 3}, {2022-01-02, 0, 3}, {2022-01-03, 1, 4}, {...} ... ]
The first (dummy) element is removed using slice(a, 2, size(a)).

inline is used to explode the array of structs.
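For reference, a tiny standalone sketch (toy literal array, not the answer’s data) of what slice and inline do together:

# standalone sketch: slice(a, 2, size(a)) drops the first element, and
# inline() turns the remaining array of structs into rows, one column per field
demo = spark.range(1).selectExpr(
    "array(struct('a' as k, 1 as v), struct('b' as k, 2 as v), struct('c' as k, 3 as v)) as a")
demo.selectExpr("inline(slice(a, 2, size(a)))").show()
# +---+---+
# |  k|  v|
# +---+---+
# |  b|  2|
# |  c|  3|
# +---+---+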


Note: it’s important to create the initial value of aggregate so that it has the proper schema (field names and types):

F.expr("array(struct(cast(null as string) date, 0L valor, 3L cum))")

The L suffixes indicate that 0 and 3 are of the bigint (long) data type (see sql-ref-literals).

The same could have been written like this:

F.expr("array(struct(null, 0, 3))").cast('array<struct<date:string,valor:bigint,cum:bigint>>')
Answered By: ZygD