PySpark: Last value of previous month by group avoiding self-join
Question:
I would like to obtain the last value an attribute takes per group over the previous month.
I can achieve this with a self-join like so:
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window
# Sample input: one row per (date, id) with an integer value.
rows = [
    ("2022-07-29", 1, 1),
    ("2022-07-30", 1, 2),
    ("2022-07-31", 1, 3),
    ("2022-08-01", 1, 4),
    ("2022-08-02", 1, 5),
    ("2022-08-03", 1, 6),
    ("2022-09-10", 1, 8),
    ("2022-09-11", 1, 9),
    ("2022-09-12", 1, 10),
    ("2022-07-29", 2, 7),
    ("2022-07-30", 2, 6),
    ("2022-07-31", 2, 5),
    ("2022-08-01", 2, 4),
    ("2022-08-02", 2, 3),
    ("2022-08-03", 2, 2),
    ("2022-09-10", 2, 8),
    ("2022-09-11", 2, 9),
    ("2022-09-12", 2, 10),
]
df = spark.createDataFrame(rows, ["date", "id", "value"]).withColumn(
    "date", F.to_date("date")
)

# Rows of one (id, month) bucket, newest date first, so first() yields the
# bucket's final value.
w = Window.partitionBy("id", "month").orderBy(F.col("date").desc())

# Start of the calendar month containing each row's date.
month_start = F.date_trunc("month", F.col("date"))

# One row per (id, month): the month's last value, keyed by the *following*
# month so it lines up with the rows that should receive it.
prev_month_last = (
    df.withColumn("month", F.add_months(month_start, 1))
    .withColumn("last_value_prev_month", F.first(F.col("value")).over(w))
    .select("id", "month", "last_value_prev_month")
    .drop_duplicates(subset=["id", "month"])
)

# Attach the previous month's last value to every row (null when the previous
# calendar month has no rows for that id).
df = (
    df.withColumn("month", month_start)
    .join(prev_month_last, on=["id", "month"], how="left")
    .drop("month")
    .orderBy(["id", "date"])
)
df.show()
+---+----------+-----+---------------------+
| id| date|value|last_value_prev_month|
+---+----------+-----+---------------------+
| 1|2022-07-29| 1| null|
| 1|2022-07-30| 2| null|
| 1|2022-07-31| 3| null|
| 1|2022-08-01| 4| 3|
| 1|2022-08-02| 5| 3|
| 1|2022-08-03| 6| 3|
| 1|2022-09-10| 8| 6|
| 1|2022-09-11| 9| 6|
| 1|2022-09-12| 10| 6|
| 2|2022-07-29| 7| null|
| 2|2022-07-30| 6| null|
| 2|2022-07-31| 5| null|
| 2|2022-08-01| 4| 5|
| 2|2022-08-02| 3| 5|
| 2|2022-08-03| 2| 5|
| 2|2022-09-10| 8| 2|
| 2|2022-09-11| 9| 2|
| 2|2022-09-12| 10| 2|
+---+----------+-----+---------------------+
This seems inefficient to me.
Can this be done with just a window, avoiding a self-join?
Answers:
Yes, we can do it using window functions to avoid a join.
# Window-only version of the self-join. Assumes data_sdf is the question's
# (date, id, value) DataFrame.
# NOTE: if an id has no rows in the immediately preceding calendar month, this
# yields the last value of the most recent earlier month with data, whereas the
# self-join yields null.
by_id = Window.partitionBy("id").orderBy("date")  # one spec, reused below
(
    data_sdf
    # Month bucket of each row. date_trunc keeps the year, so e.g. January
    # 2022 and January 2023 are distinct buckets (month() would merge them).
    .withColumn("mth", F.date_trunc("month", "date"))
    # 1 on the first row of a new month bucket, 0 otherwise; null on the very
    # first row of each id (lag has nothing to compare against).
    .withColumn("blah", (F.col("mth") != F.lag("mth").over(by_id)).cast("int"))
    # On those month-opening rows, take the value of the row just before it,
    # i.e. the last value of the preceding month that has data.
    .withColumn("blah2", F.when(F.col("blah") == 1, F.lag("value").over(by_id)))
    # Forward-fill blah2 over the remaining rows of the id.
    .withColumn(
        "last_value_prev_month",
        F.last("blah2", ignorenulls=True).over(by_id),
    )
    .drop("mth", "blah", "blah2")
    .show()
)
# +----------+---+-----+---------------------+
# | date| id|value|last_value_prev_month|
# +----------+---+-----+---------------------+
# |2022-07-29| 1| 1| null|
# |2022-07-30| 1| 2| null|
# |2022-07-31| 1| 3| null|
# |2022-08-01| 1| 4| 3|
# |2022-08-02| 1| 5| 3|
# |2022-08-03| 1| 6| 3|
# |2022-09-10| 1| 8| 6|
# |2022-09-11| 1| 9| 6|
# |2022-09-12| 1| 10| 6|
# |2022-07-29| 2| 7| null|
# |2022-07-30| 2| 6| null|
# |2022-07-31| 2| 5| null|
# |2022-08-01| 2| 4| 5|
# |2022-08-02| 2| 3| 5|
# |2022-08-03| 2| 2| 5|
# |2022-09-10| 2| 8| 2|
# |2022-09-11| 2| 9| 2|
# |2022-09-12| 2| 10| 2|
# +----------+---+-----+---------------------+
- `blah` flags the record with the first date in a month.
- `blah2` holds the lag of `value` for the aforementioned record, i.e. the `value` on the last date of the previous month.
- Use the `last()` window function on the aforementioned `blah2` field to fill the nulls.
samkart has provided the main idea of the answer above. Here I provide a solution with two instead of three windows.
# Order by an integer day number rather than by epoch seconds. Casting a date
# to a timestamp uses the session time zone, so across a spring-forward DST
# change two consecutive dates are only 23 hours apart and a fixed
# -86400-second bound would exclude the previous date from the frame.
day_number = F.datediff(F.col("date"), F.to_date(F.lit("1970-01-01")))

# All rows of the same id strictly before the current row's date.
w1 = (
    Window
    .partitionBy("id")
    .orderBy(day_number)
    .rangeBetween(Window.unboundedPreceding, -1)
)
# All rows of the same id in the current row's calendar month, earliest first.
w2 = (
    Window
    .partitionBy("id", F.date_trunc("month", "date"))
    .orderBy(F.col("date"))
)
# value_prev_day: most recent value before the current date.
# last_value_prev_month: value_prev_day of the month's first row.
# NOTE: if the previous calendar month has no rows for an id, this yields the
# last value of the most recent earlier month (the self-join yields null).
(
    df
    .withColumn("value_prev_day", F.last("value").over(w1))
    .withColumn("last_value_prev_month", F.first("value_prev_day").over(w2))
    .orderBy(["id", "date"])
    .show()
)
+----------+---+-----+--------------+---------------------+
| date| id|value|value_prev_day|last_value_prev_month|
+----------+---+-----+--------------+---------------------+
|2022-07-29| 1| 1| null| null|
|2022-07-30| 1| 2| 1| null|
|2022-07-31| 1| 3| 2| null|
|2022-08-01| 1| 4| 3| 3|
|2022-08-02| 1| 5| 4| 3|
|2022-08-03| 1| 6| 5| 3|
|2022-09-10| 1| 8| 6| 6|
|2022-09-11| 1| 9| 8| 6|
|2022-09-12| 1| 10| 9| 6|
|2022-07-29| 2| 7| null| null|
|2022-07-30| 2| 6| 7| null|
|2022-07-31| 2| 5| 6| null|
|2022-08-01| 2| 4| 5| 5|
|2022-08-02| 2| 3| 4| 5|
|2022-08-03| 2| 2| 3| 5|
|2022-09-10| 2| 8| 2| 2|
|2022-09-11| 2| 9| 8| 2|
|2022-09-12| 2| 10| 9| 2|
+----------+---+-----+--------------+---------------------+
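`value_prev_day` is the most recent value of `value` strictly before the current date (per `id`). For consecutive dates this is the previous day's value; otherwise it comes from the last earlier date with data (e.g. for 2022-09-10 it is 6, from 2022-08-03).
- Once we have this, we can create another partition of the data, by `id` and the month of the `date` for the current row. We then order this partition by `date`, so the earliest date of the month present in the data is the first row in the partition. We assign `last_value_prev_month` as `first(value_prev_day)` over this partition. This has to be the last value of the previous month, since it is the `value_prev_day` of the month's first row. If the previous month has no rows for that `id`, it is instead the last value of the most recent earlier month.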
I would like to obtain the last value an attribute takes per group over the previous month.
I can achieve this with a self-join like so:
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window
# Sample input: one row per (date, id) with an integer value.
rows = [
    ("2022-07-29", 1, 1),
    ("2022-07-30", 1, 2),
    ("2022-07-31", 1, 3),
    ("2022-08-01", 1, 4),
    ("2022-08-02", 1, 5),
    ("2022-08-03", 1, 6),
    ("2022-09-10", 1, 8),
    ("2022-09-11", 1, 9),
    ("2022-09-12", 1, 10),
    ("2022-07-29", 2, 7),
    ("2022-07-30", 2, 6),
    ("2022-07-31", 2, 5),
    ("2022-08-01", 2, 4),
    ("2022-08-02", 2, 3),
    ("2022-08-03", 2, 2),
    ("2022-09-10", 2, 8),
    ("2022-09-11", 2, 9),
    ("2022-09-12", 2, 10),
]
df = spark.createDataFrame(rows, ["date", "id", "value"]).withColumn(
    "date", F.to_date("date")
)

# Rows of one (id, month) bucket, newest date first, so first() yields the
# bucket's final value.
w = Window.partitionBy("id", "month").orderBy(F.col("date").desc())

# Start of the calendar month containing each row's date.
month_start = F.date_trunc("month", F.col("date"))

# One row per (id, month): the month's last value, keyed by the *following*
# month so it lines up with the rows that should receive it.
prev_month_last = (
    df.withColumn("month", F.add_months(month_start, 1))
    .withColumn("last_value_prev_month", F.first(F.col("value")).over(w))
    .select("id", "month", "last_value_prev_month")
    .drop_duplicates(subset=["id", "month"])
)

# Attach the previous month's last value to every row (null when the previous
# calendar month has no rows for that id).
df = (
    df.withColumn("month", month_start)
    .join(prev_month_last, on=["id", "month"], how="left")
    .drop("month")
    .orderBy(["id", "date"])
)
df.show()
+---+----------+-----+---------------------+
| id| date|value|last_value_prev_month|
+---+----------+-----+---------------------+
| 1|2022-07-29| 1| null|
| 1|2022-07-30| 2| null|
| 1|2022-07-31| 3| null|
| 1|2022-08-01| 4| 3|
| 1|2022-08-02| 5| 3|
| 1|2022-08-03| 6| 3|
| 1|2022-09-10| 8| 6|
| 1|2022-09-11| 9| 6|
| 1|2022-09-12| 10| 6|
| 2|2022-07-29| 7| null|
| 2|2022-07-30| 6| null|
| 2|2022-07-31| 5| null|
| 2|2022-08-01| 4| 5|
| 2|2022-08-02| 3| 5|
| 2|2022-08-03| 2| 5|
| 2|2022-09-10| 8| 2|
| 2|2022-09-11| 9| 2|
| 2|2022-09-12| 10| 2|
+---+----------+-----+---------------------+
This seems inefficient to me.
Can this be done with just a window, avoiding a self-join?
Yes, we can do it using window functions to avoid a join.
# Window-only version of the self-join. Assumes data_sdf is the question's
# (date, id, value) DataFrame.
# NOTE: if an id has no rows in the immediately preceding calendar month, this
# yields the last value of the most recent earlier month with data, whereas the
# self-join yields null.
by_id = Window.partitionBy("id").orderBy("date")  # one spec, reused below
(
    data_sdf
    # Month bucket of each row. date_trunc keeps the year, so e.g. January
    # 2022 and January 2023 are distinct buckets (month() would merge them).
    .withColumn("mth", F.date_trunc("month", "date"))
    # 1 on the first row of a new month bucket, 0 otherwise; null on the very
    # first row of each id (lag has nothing to compare against).
    .withColumn("blah", (F.col("mth") != F.lag("mth").over(by_id)).cast("int"))
    # On those month-opening rows, take the value of the row just before it,
    # i.e. the last value of the preceding month that has data.
    .withColumn("blah2", F.when(F.col("blah") == 1, F.lag("value").over(by_id)))
    # Forward-fill blah2 over the remaining rows of the id.
    .withColumn(
        "last_value_prev_month",
        F.last("blah2", ignorenulls=True).over(by_id),
    )
    .drop("mth", "blah", "blah2")
    .show()
)
# +----------+---+-----+---------------------+
# | date| id|value|last_value_prev_month|
# +----------+---+-----+---------------------+
# |2022-07-29| 1| 1| null|
# |2022-07-30| 1| 2| null|
# |2022-07-31| 1| 3| null|
# |2022-08-01| 1| 4| 3|
# |2022-08-02| 1| 5| 3|
# |2022-08-03| 1| 6| 3|
# |2022-09-10| 1| 8| 6|
# |2022-09-11| 1| 9| 6|
# |2022-09-12| 1| 10| 6|
# |2022-07-29| 2| 7| null|
# |2022-07-30| 2| 6| null|
# |2022-07-31| 2| 5| null|
# |2022-08-01| 2| 4| 5|
# |2022-08-02| 2| 3| 5|
# |2022-08-03| 2| 2| 5|
# |2022-09-10| 2| 8| 2|
# |2022-09-11| 2| 9| 2|
# |2022-09-12| 2| 10| 2|
# +----------+---+-----+---------------------+
- `blah` flags the record with the first date in a month.
- `blah2` holds the lag of `value` for the aforementioned record, i.e. the `value` on the last date of the previous month.
- Use the `last()` window function on the aforementioned `blah2` field to fill the nulls.
samkart has provided the main idea of the answer above. Here I provide a solution with two instead of three windows.
# Order by an integer day number rather than by epoch seconds. Casting a date
# to a timestamp uses the session time zone, so across a spring-forward DST
# change two consecutive dates are only 23 hours apart and a fixed
# -86400-second bound would exclude the previous date from the frame.
day_number = F.datediff(F.col("date"), F.to_date(F.lit("1970-01-01")))

# All rows of the same id strictly before the current row's date.
w1 = (
    Window
    .partitionBy("id")
    .orderBy(day_number)
    .rangeBetween(Window.unboundedPreceding, -1)
)
# All rows of the same id in the current row's calendar month, earliest first.
w2 = (
    Window
    .partitionBy("id", F.date_trunc("month", "date"))
    .orderBy(F.col("date"))
)
# value_prev_day: most recent value before the current date.
# last_value_prev_month: value_prev_day of the month's first row.
# NOTE: if the previous calendar month has no rows for an id, this yields the
# last value of the most recent earlier month (the self-join yields null).
(
    df
    .withColumn("value_prev_day", F.last("value").over(w1))
    .withColumn("last_value_prev_month", F.first("value_prev_day").over(w2))
    .orderBy(["id", "date"])
    .show()
)
+----------+---+-----+--------------+---------------------+
| date| id|value|value_prev_day|last_value_prev_month|
+----------+---+-----+--------------+---------------------+
|2022-07-29| 1| 1| null| null|
|2022-07-30| 1| 2| 1| null|
|2022-07-31| 1| 3| 2| null|
|2022-08-01| 1| 4| 3| 3|
|2022-08-02| 1| 5| 4| 3|
|2022-08-03| 1| 6| 5| 3|
|2022-09-10| 1| 8| 6| 6|
|2022-09-11| 1| 9| 8| 6|
|2022-09-12| 1| 10| 9| 6|
|2022-07-29| 2| 7| null| null|
|2022-07-30| 2| 6| 7| null|
|2022-07-31| 2| 5| 6| null|
|2022-08-01| 2| 4| 5| 5|
|2022-08-02| 2| 3| 4| 5|
|2022-08-03| 2| 2| 3| 5|
|2022-09-10| 2| 8| 2| 2|
|2022-09-11| 2| 9| 8| 2|
|2022-09-12| 2| 10| 9| 2|
+----------+---+-----+--------------+---------------------+
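`value_prev_day` is the most recent value of `value` strictly before the current date (per `id`). For consecutive dates this is the previous day's value; otherwise it comes from the last earlier date with data (e.g. for 2022-09-10 it is 6, from 2022-08-03).
- Once we have this, we can create another partition of the data, by `id` and the month of the `date` for the current row. We then order this partition by `date`, so the earliest date of the month present in the data is the first row in the partition. We assign `last_value_prev_month` as `first(value_prev_day)` over this partition. This has to be the last value of the previous month, since it is the `value_prev_day` of the month's first row. If the previous month has no rows for that `id`, it is instead the last value of the most recent earlier month.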