PySpark – Cumulative sum with limits
Question:
I have a dataframe as follows:
+-------+----------+-----+
|user_id| date|valor|
+-------+----------+-----+
| 1|2022-01-01| 0|
| 1|2022-01-02| 0|
| 1|2022-01-03| 1|
| 1|2022-01-04| 1|
| 1|2022-01-05| 1|
| 1|2022-01-06| 0|
| 1|2022-01-07| 0|
| 1|2022-01-08| 0|
| 1|2022-01-09| 1|
| 1|2022-01-10| 1|
| 1|2022-01-11| 1|
| 1|2022-01-12| 0|
| 1|2022-01-13| 0|
| 1|2022-01-14| -1|
| 1|2022-01-15| -1|
| 1|2022-01-16| -1|
| 1|2022-01-17| -1|
| 1|2022-01-18| -1|
| 1|2022-01-19| -1|
| 1|2022-01-20| 0|
+-------+----------+-----+
The goal is to calculate a score for each user_id, using valor as the base: the score starts at 3 and increases or decreases by 1 as it moves down the valor column. The main problem is that the score can't go below 1 or above 5, so the running sum must always stay within that range while still carrying the last value forward so the next step is computed correctly. So what I expect is this:
+-------+----------+-----+-----+
|user_id| date|valor|score|
+-------+----------+-----+-----+
| 1|2022-01-01| 0| 3|
| 1|2022-01-02| 0| 3|
| 1|2022-01-03| 1| 4|
| 1|2022-01-04| 1| 5|
| 1|2022-01-05| 1| 5|
| 1|2022-01-06| 0| 5|
| 1|2022-01-07| 0| 5|
| 1|2022-01-08| 0| 5|
| 1|2022-01-09| 1| 5|
| 1|2022-01-10| -1| 4|
| 1|2022-01-11| -1| 3|
| 1|2022-01-12| 0| 3|
| 1|2022-01-13| 0| 3|
| 1|2022-01-14| -1| 2|
| 1|2022-01-15| -1| 1|
| 1|2022-01-16| 1| 2|
| 1|2022-01-17| -1| 1|
| 1|2022-01-18| -1| 1|
| 1|2022-01-19| 1| 2|
| 1|2022-01-20| 0| 2|
+-------+----------+-----+-----+
So far, I've built a window to rank the valor column, so I can keep track of the length of each increasing or decreasing sequence and remove from valor the sequences longer than 4, but I don't know how to keep the cumulative sum of valor_ in the range [1, 5]:
+-------+----------+----+-----+------+
|user_id| date|rank|valor|valor_|
+-------+----------+----+-----+------+
| 1|2022-01-01| 0| 0| 0|
| 1|2022-01-02| 0| 0| 0|
| 1|2022-01-03| 1| 1| 1|
| 1|2022-01-04| 2| 1| 1|
| 1|2022-01-05| 3| 1| 1|
| 1|2022-01-06| 0| 0| 0|
| 1|2022-01-07| 0| 0| 0|
| 1|2022-01-08| 0| 0| 0|
| 1|2022-01-09| 1| 1| 1|
| 1|2022-01-10| 2| 1| 1|
| 1|2022-01-11| 3| 1| 1|
| 1|2022-01-12| 0| 0| 0|
| 1|2022-01-13| 0| 0| 0|
| 1|2022-01-14| 1| -1| -1|
| 1|2022-01-15| 2| -1| -1|
| 1|2022-01-16| 3| -1| -1|
| 1|2022-01-17| 4| -1| -1|
| 1|2022-01-18| 5| -1| 0|
| 1|2022-01-19| 6| -1| 0|
+-------+----------+----+-----+------+
As you can see, the result here is not what I expected:
+-------+----------+----+-----+------+-----+
|user_id| date|rank|valor|valor_|score|
+-------+----------+----+-----+------+-----+
| 1|2022-01-01| 0| 0| 0| 3|
| 1|2022-01-02| 0| 0| 0| 3|
| 1|2022-01-03| 1| 1| 1| 4|
| 1|2022-01-04| 2| 1| 1| 5|
| 1|2022-01-05| 3| 1| 1| 6|
| 1|2022-01-06| 0| 0| 0| 6|
| 1|2022-01-07| 0| 0| 0| 6|
| 1|2022-01-08| 0| 0| 0| 6|
| 1|2022-01-09| 1| 1| 1| 7|
| 1|2022-01-10| 2| 1| 1| 8|
| 1|2022-01-11| 3| 1| 1| 9|
| 1|2022-01-12| 0| 0| 0| 9|
| 1|2022-01-13| 0| 0| 0| 9|
| 1|2022-01-14| 1| -1| -1| 8|
| 1|2022-01-15| 2| -1| -1| 7|
| 1|2022-01-16| 3| -1| -1| 6|
| 1|2022-01-17| 4| -1| -1| 5|
| 1|2022-01-18| 5| -1| 0| 5|
| 1|2022-01-19| 6| -1| 0| 5|
| 1|2022-01-20| 0| 0| 0| 5|
+-------+----------+----+-----+------+-----+
Answers:
tl;dr – a complex approach similar to this – consider it a last resort due to its complexity
A python function can keep track of the previous cumulative sum value, and that function can be used with flatMapValues() to process the data.
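For context, flatMapValues() applies a function to the value part of each key-value pair and flattens the result back into one record per output element. A minimal standalone illustration (the sample pairs are made up):
rdd = spark.sparkContext.parallelize([('a', [1, 2]), ('b', [3])])
rdd.flatMapValues(lambda xs: [x * 10 for x in xs]).collect()
# [('a', 10), ('a', 20), ('b', 30)]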
Consider the following input data
import random

import pandas as pd
from pyspark.sql import functions as func

data1_ls = [(1, k.strftime('%Y-%m-%d'), random.randint(-1, 1)) for k in pd.date_range(pd.to_datetime('2022-01-01'), pd.to_datetime('2022-01-20'))]
data2_ls = [(2, k.strftime('%Y-%m-%d'), random.randint(-1, 1)) for k in pd.date_range(pd.to_datetime('2022-04-01'), pd.to_datetime('2022-04-30'))]

data1_sdf = spark.sparkContext.parallelize(data1_ls).toDF(['user', 'dt', 'valor']). \
    withColumn('dt', func.col('dt').cast('date'))
data2_sdf = spark.sparkContext.parallelize(data2_ls).toDF(['user', 'dt', 'valor']). \
    withColumn('dt', func.col('dt').cast('date'))

data_sdf = data1_sdf.unionByName(data2_sdf)
# +----+----------+-----+
# |user| dt|valor|
# +----+----------+-----+
# | 1|2022-01-01| 1|
# | 1|2022-01-02| -1|
# | 1|2022-01-03| 0|
# | 1|2022-01-04| 1|
# | 1|2022-01-05| 0|
# +----+----------+-----+
We can write a python function that computes the running sum while keeping track of its previous value. Spark will ship this function to the executors, so each group is processed in parallel.
def cumsum_in_range(groupedRows, initial_value=3):
    """Cumulative sum of `valor` over the given rows, clamped to the range [1, 5]."""
    res = []
    frstRec = True
    initVal = initial_value
    for row in groupedRows:
        if frstRec:
            # data starts from a static value
            frstRec = False
            cumsum = initVal + row.valor
        else:
            cumsum = prev_cumsum + row.valor
        # clamp the running sum to the [1, 5] range
        if cumsum > 5:
            cumsum = 5
        elif cumsum < 1:
            cumsum = 1
        prev_cumsum = cumsum  # keeping track of the latest sum for next iteration
        res.append([item for item in row] + [cumsum])
    return res
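Before wiring it into Spark, the function can be sanity-checked locally on plain Row objects; a quick sketch with made-up sample rows:
from pyspark.sql import Row

rows = [Row(user=1, dt='2022-01-01', valor=1),
        Row(user=1, dt='2022-01-02', valor=1),
        Row(user=1, dt='2022-01-03', valor=1),
        Row(user=1, dt='2022-01-04', valor=-1)]
[r[-1] for r in cumsum_in_range(rows)]   # last item of each result row is the new cumsum
# [4, 5, 5, 4] -- the third step is clamped at 5 instead of reaching 6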
To apply the function, we'll use groupBy() and flatMapValues(). The groupBy() partitions the data based on the key provided. We also need the data ordered by the date field for the cumulative sum, so sorted() is used with the date field as the sort key.
# run the python function and keep only the resulting values
res_vals = data_sdf.rdd. \
    groupBy(lambda gk: gk.user). \
    flatMapValues(lambda r: cumsum_in_range(sorted(r, key=lambda ok: ok.dt))). \
    values()

# create schema for the new column in previous dataframe
data_schema = data_sdf.withColumn('dropme', func.lit(None).cast('int')). \
    drop('dropme'). \
    schema. \
    add('cumsum', 'integer')

# create a dataframe with the new values
res_sdf = spark.createDataFrame(res_vals, data_schema)
The res_sdf dataframe will have the cumulative sum column created for each user, based on the python function defined above.
res_sdf. \
    filter(func.col('user') == 1). \
    orderBy(['user', 'dt']). \
    show()
# +----+----------+-----+------+
# |user| dt|valor|cumsum|
# +----+----------+-----+------+
# | 1|2022-01-01| 1| 4|
# | 1|2022-01-02| -1| 3|
# | 1|2022-01-03| 0| 3|
# | 1|2022-01-04| 1| 4|
# | 1|2022-01-05| 0| 4|
# | 1|2022-01-06| 1| 5|
# | 1|2022-01-07| 0| 5|
# | 1|2022-01-08| 1| 5|
# | 1|2022-01-09| 0| 5|
# | 1|2022-01-10| -1| 4|
# | 1|2022-01-11| -1| 3|
# | 1|2022-01-12| -1| 2|
# | 1|2022-01-13| 1| 3|
# | 1|2022-01-14| -1| 2|
# | 1|2022-01-15| 1| 3|
# | 1|2022-01-16| -1| 2|
# | 1|2022-01-17| 0| 2|
# | 1|2022-01-18| 1| 3|
# | 1|2022-01-19| 0| 3|
# | 1|2022-01-20| -1| 2|
# +----+----------+-----+------+
res_sdf. \
    filter(func.col('user') == 2). \
    orderBy(['user', 'dt']). \
    show()
# +----+----------+-----+------+
# |user| dt|valor|cumsum|
# +----+----------+-----+------+
# | 2|2022-04-01| -1| 2|
# | 2|2022-04-02| 0| 2|
# | 2|2022-04-03| 1| 3|
# | 2|2022-04-04| -1| 2|
# | 2|2022-04-05| 1| 3|
# | 2|2022-04-06| 0| 3|
# | 2|2022-04-07| 1| 4|
# | 2|2022-04-08| -1| 3|
# | 2|2022-04-09| 0| 3|
# | 2|2022-04-10| 0| 3|
# | 2|2022-04-11| -1| 2|
# | 2|2022-04-12| 1| 3|
# | 2|2022-04-13| 0| 3|
# | 2|2022-04-14| 0| 3|
# | 2|2022-04-15| 1| 4|
# | 2|2022-04-16| -1| 3|
# | 2|2022-04-17| 0| 3|
# | 2|2022-04-18| 0| 3|
# | 2|2022-04-19| 1| 4|
# | 2|2022-04-20| 1| 5|
# +----+----------+-----+------+
# only showing top 20 rows
In such cases, we usually think of window functions to carry a calculation from one row to the next. But this case is different, because the window would have to keep track of its own output. So a window cannot help here.
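To see why, clamping at every step is not the same as clamping a plain running total after the fact; a minimal plain-Python illustration (the valor values are made up):
vals = [1, 1, 1, -1, -1, -1]   # three increases, then three decreases, starting from 3

# clamp applied only to a plain cumulative sum (what a window sum would give):
running, naive = 3, []
for v in vals:
    running += v                          # drifts up to 6 internally
    naive.append(max(1, min(5, running)))
# naive -> [4, 5, 5, 5, 4, 3]

# clamp applied at every step (what the question asks for):
score, stepwise = 3, []
for v in vals:
    score = max(1, min(5, score + v))     # each step sees the already-clamped value
    stepwise.append(score)
# stepwise -> [4, 5, 5, 4, 3, 2]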
Main idea. Instead of operating on rows, one can do the work on grouped/aggregated arrays. That works very well here, because we have a key to use in groupBy: the table is divided into chunks of data per key, so the calculations are parallelized.
Input:
from pyspark.sql import functions as F
df = spark.createDataFrame(
[(1, '2022-01-01', 0),
(1, '2022-01-02', 0),
(1, '2022-01-03', 1),
(1, '2022-01-04', 1),
(1, '2022-01-05', 1),
(1, '2022-01-06', 0),
(1, '2022-01-07', 0),
(1, '2022-01-08', 0),
(1, '2022-01-09', 1),
(1, '2022-01-10', 1),
(1, '2022-01-11', 1),
(1, '2022-01-12', 0),
(1, '2022-01-13', 0),
(1, '2022-01-14', -1),
(1, '2022-01-15', -1),
(1, '2022-01-16', -1),
(1, '2022-01-17', -1),
(1, '2022-01-18', -1),
(1, '2022-01-19', -1),
(1, '2022-01-20', 0)],
['user_id', 'date', 'valor'])
Script:
df = df.groupBy('user_id').agg(
F.aggregate(
F.array_sort(F.collect_list(F.struct('date', 'valor'))),
F.expr("array(struct(cast(null as string) date, 0L valor, 3L cum))"),
lambda acc, x: F.array_union(
acc,
F.array(x.withField(
'cum',
F.greatest(F.lit(1), F.least(F.lit(5), x['valor'] + F.element_at(acc, -1)['cum']))
))
)
).alias("a")
)
df = df.selectExpr("user_id", "inline(slice(a, 2, size(a)))")
df.show()
# +-------+----------+-----+---+
# |user_id| date|valor|cum|
# +-------+----------+-----+---+
# | 1|2022-01-01| 0| 3|
# | 1|2022-01-02| 0| 3|
# | 1|2022-01-03| 1| 4|
# | 1|2022-01-04| 1| 5|
# | 1|2022-01-05| 1| 5|
# | 1|2022-01-06| 0| 5|
# | 1|2022-01-07| 0| 5|
# | 1|2022-01-08| 0| 5|
# | 1|2022-01-09| 1| 5|
# | 1|2022-01-10| 1| 5|
# | 1|2022-01-11| 1| 5|
# | 1|2022-01-12| 0| 5|
# | 1|2022-01-13| 0| 5|
# | 1|2022-01-14| -1| 4|
# | 1|2022-01-15| -1| 3|
# | 1|2022-01-16| -1| 2|
# | 1|2022-01-17| -1| 1|
# | 1|2022-01-18| -1| 1|
# | 1|2022-01-19| -1| 1|
# | 1|2022-01-20| 0| 1|
# +-------+----------+-----+---+
Explanation
Groups are created based on "user_id". The aggregation for these groups lies in this line:
F.array_sort(F.collect_list(F.struct('date', 'valor')))
This creates an array (collect_list) for every "user_id". The arrays contain structs of 2 fields: date and valor.
+-------+-----------------------------------------------+
|user_id|a |
+-------+-----------------------------------------------+
|1 |[{2022-01-01, 0}, {2022-01-02, 0}, {...} ... ] |
+-------+-----------------------------------------------+
array_sort is used to make sure all the structs inside are sorted, because other steps will depend on it.
Everything else inside agg is there to transform the result of the above aggregation.
The main part of the code is aggregate. It takes an array, "loops" through every element and returns one value (in our case, this value is made to be an array too). It works like this: you take the initial value (array(struct(cast(null as string) date, 0L valor, 3L cum))) and merge it with the first element of the array using the provided function (the lambda). The result is then used in place of the initial value for the next run: you do the merge again, but with the following element of the array. And so on.
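The fold may be easier to see in plain Python; a rough analogy of the same accumulation, with a list standing in for the Spark array and dicts standing in for the structs:
from functools import reduce

def merge(acc, x):
    # mirror of the lambda: previous 'cum' plus 'valor', clamped into [1, 5]
    cum = max(1, min(5, x['valor'] + acc[-1]['cum']))
    return acc + [{**x, 'cum': cum}]

init = [{'date': None, 'valor': 0, 'cum': 3}]   # the initial value
rows = [{'date': '2022-01-03', 'valor': 1},
        {'date': '2022-01-04', 'valor': 1},
        {'date': '2022-01-05', 'valor': 1}]
reduce(merge, rows, init)
# 'cum' goes 3 -> 4 -> 5 -> 5 (clamped)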
In this case, the lambda function performs array_union, which makes a union of two arrays having identical schemas.
1. We take the initial value (an array of structs) as the acc variable:
[{null, 0, 3}]
It is already in a form that array_union can consume.
2. We take the first element inside the 'a' column's array as the x variable:
{2022-01-01, 0}
It's a struct, so its schema does not match acc (an array of structs); some processing is needed, and the calculation also has to be done at this step, since this is where we have access to both variables.
3. We create an array of structs by enclosing the x struct inside F.array(); we also have to add one more field, as the x struct currently has just 2 fields:
F.array(x.withField('cum', ...))
4. Inside .withField() we provide the expression for the new field:
F.greatest(
    F.lit(1),
    F.least(
        F.lit(5),
        x['valor'] + F.element_at(acc, -1)['cum']
    )
)
element_at(acc, -1) takes the last struct of the acc array; ['cum'] takes the field 'cum' from that struct; x['valor'] + adds the 'valor' field from the x struct. F.least() ensures that 'cum' never exceeds 5 (it takes the smaller of the new 'cum' and 5), and F.greatest() ensures that 'cum' never drops below 1.
5. Both acc and the newly created array of structs now have identical schemas and proper data, so they can be combined with array_union. The result is assigned to the acc variable, while the x variable gets the next value from the 'a' array. The process continues from step 3.
Finally, the result of aggregate looks like this:
[{null, 0, 3}, {2022-01-01, 0, 3}, {2022-01-02, 0, 3}, {2022-01-03, 1, 4}, {...} ... ]
The first element is removed using slice(a, 2, size(a)), and inline is used to explode the array of structs into columns.
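For reference, a minimal standalone illustration of the slice + inline step, with literal structs made up to mimic the aggregate output:
spark.sql("""
    SELECT inline(slice(array(
        struct(cast(null as string) AS date, 0L AS valor, 3L AS cum),
        struct('2022-01-01' AS date, 0L AS valor, 3L AS cum),
        struct('2022-01-02' AS date, 0L AS valor, 3L AS cum)
    ), 2, 2))
""").show()
# +----------+-----+---+
# |      date|valor|cum|
# +----------+-----+---+
# |2022-01-01|    0|  3|
# |2022-01-02|    0|  3|
# +----------+-----+---+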
Note. It's important to create the initial value of aggregate in a way that gives it the proper schema (column/field names and types):
F.expr("array(struct(cast(null as string) date, 0L valor, 3L cum))")
The L letters indicate that 0 and 3 are of the bigint (long) data type (see sql-ref-literals).
The same could have been written like this:
F.expr("array(struct(null, 0, 3))").cast('array<struct<date:string,valor:bigint,cum:bigint>>')