PySpark: Last value of previous month by group avoiding self-join

Question:

I would like to obtain the last value an attribute takes per group over the previous month.

I can achieve this with a self-join like so:

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

df = (
    spark.createDataFrame(
        [
            ("2022-07-29", 1, 1),
            ("2022-07-30", 1, 2),
            ("2022-07-31", 1, 3),
            ("2022-08-01", 1, 4),
            ("2022-08-02", 1, 5),
            ("2022-08-03", 1, 6),
            ("2022-09-10", 1, 8),
            ("2022-09-11", 1, 9),
            ("2022-09-12", 1, 10),
            ("2022-07-29", 2, 7),
            ("2022-07-30", 2, 6),
            ("2022-07-31", 2, 5),
            ("2022-08-01", 2, 4),
            ("2022-08-02", 2, 3),
            ("2022-08-03", 2, 2),
            ("2022-09-10", 2, 8),
            ("2022-09-11", 2, 9),
            ("2022-09-12", 2, 10),
        ],
        ["date", "id", "value"],
    )
    .withColumn("date", F.to_date(F.col("date")))
)

# Sorting each (id, month) partition by date descending makes first() return the month's last value
w = Window.partitionBy("id", "month").orderBy(F.col("date").desc())
df = (
    df
    .withColumn("month", F.date_trunc("month", F.col("date")))
    .join(
        df
        # Shift each row's month forward by one so it lines up with the following month
        .withColumn("month", F.add_months(F.date_trunc("month", F.col("date")), 1))
        .withColumn("last_value_prev_month", F.first(F.col("value")).over(w))
        .select("id", "month", "last_value_prev_month")
        .drop_duplicates(subset=["id", "month"]),
        on=["id", "month"],
        how="left"
    )
    .drop("month")
    .orderBy(["id", "date"])
)
df.show()

+---+----------+-----+---------------------+
| id|      date|value|last_value_prev_month|
+---+----------+-----+---------------------+
|  1|2022-07-29|    1|                 null|
|  1|2022-07-30|    2|                 null|
|  1|2022-07-31|    3|                 null|
|  1|2022-08-01|    4|                    3|
|  1|2022-08-02|    5|                    3|
|  1|2022-08-03|    6|                    3|
|  1|2022-09-10|    8|                    6|
|  1|2022-09-11|    9|                    6|
|  1|2022-09-12|   10|                    6|
|  2|2022-07-29|    7|                 null|
|  2|2022-07-30|    6|                 null|
|  2|2022-07-31|    5|                 null|
|  2|2022-08-01|    4|                    5|
|  2|2022-08-02|    3|                    5|
|  2|2022-08-03|    2|                    5|
|  2|2022-09-10|    8|                    2|
|  2|2022-09-11|    9|                    2|
|  2|2022-09-12|   10|                    2|
+---+----------+-----+---------------------+

This seems inefficient to me.

Can this be done with just a window, avoiding a self-join?
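A plain-Python sketch of what the self-join above computes can help check the expected output without a Spark cluster. The names ROWS and self_join_reference are illustrative only and are not part of PySpark. It assumes "previous month" means the previous calendar month, which is what add_months(..., 1) encodes: if an id has no rows in some month, rows in the following month get null.

```python
from datetime import date

# Sample rows copied from the question: (date, id, value)
ROWS = [
    ("2022-07-29", 1, 1), ("2022-07-30", 1, 2), ("2022-07-31", 1, 3),
    ("2022-08-01", 1, 4), ("2022-08-02", 1, 5), ("2022-08-03", 1, 6),
    ("2022-09-10", 1, 8), ("2022-09-11", 1, 9), ("2022-09-12", 1, 10),
    ("2022-07-29", 2, 7), ("2022-07-30", 2, 6), ("2022-07-31", 2, 5),
    ("2022-08-01", 2, 4), ("2022-08-02", 2, 3), ("2022-08-03", 2, 2),
    ("2022-09-10", 2, 8), ("2022-09-11", 2, 9), ("2022-09-12", 2, 10),
]


def month_start(d):
    """Like date_trunc('month', d): first day of d's month."""
    return d.replace(day=1)


def add_one_month(d):
    """Like add_months(d, 1) for a first-of-month date."""
    return date(d.year + d.month // 12, d.month % 12 + 1, 1)


def self_join_reference(rows):
    """Hypothetical reference for the self-join: (id, date, value, last_value_prev_month)."""
    parsed = [(date.fromisoformat(s), i, v) for s, i, v in rows]
    # Right side of the join: last value per (id, month), keyed by the NEXT month.
    last_in_month = {}
    for d, i, v in sorted(parsed):
        # Rows are visited in date order, so later dates overwrite earlier ones.
        last_in_month[(i, add_one_month(month_start(d)))] = v
    # Left join back on (id, month of the current row); a missing key plays the role of null.
    return [
        (i, d.isoformat(), v, last_in_month.get((i, month_start(d))))
        for d, i, v in sorted(parsed, key=lambda r: (r[1], r[0]))
    ]
```

The last column matches the table above: null for July, then 3 and 6 for id 1, and 5 and 2 for id 2.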

Asked By: user10166790


Answers:

Yes, we can do it using window functions to avoid a join.

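from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

# data_sdf is the question's input DataFrame (columns date, id, value)
data_sdf = df

(
    data_sdf
    .withColumn('mth', func.month('date'))
    # 1 when the month differs from the previous row's month (per id), null on the first row
    .withColumn('blah',
                (func.col('mth') != func.lag('mth').over(wd.partitionBy('id').orderBy('date'))).cast('int')
                )
    # On month-change rows, take the previous row's value; null elsewhere
    .withColumn('blah2',
                func.when(func.col('blah') == 1,
                          func.lag('value').over(wd.partitionBy('id').orderBy('date'))
                          )
                )
    # Forward-fill the last non-null blah2 down the date-ordered rows
    .withColumn('last_value_prev_month',
                func.last('blah2', ignorenulls=True).over(wd.partitionBy('id').orderBy('date'))
                )
    .drop('mth', 'blah', 'blah2')
    .show()
)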

# +----------+---+-----+---------------------+
# |      date| id|value|last_value_prev_month|
# +----------+---+-----+---------------------+
# |2022-07-29|  1|    1|                 null|
# |2022-07-30|  1|    2|                 null|
# |2022-07-31|  1|    3|                 null|
# |2022-08-01|  1|    4|                    3|
# |2022-08-02|  1|    5|                    3|
# |2022-08-03|  1|    6|                    3|
# |2022-09-10|  1|    8|                    6|
# |2022-09-11|  1|    9|                    6|
# |2022-09-12|  1|   10|                    6|
# |2022-07-29|  2|    7|                 null|
# |2022-07-30|  2|    6|                 null|
# |2022-07-31|  2|    5|                 null|
# |2022-08-01|  2|    4|                    5|
# |2022-08-02|  2|    3|                    5|
# |2022-08-03|  2|    2|                    5|
# |2022-09-10|  2|    8|                    2|
# |2022-09-11|  2|    9|                    2|
# |2022-09-12|  2|   10|                    2|
# +----------+---+-----+---------------------+
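  • blah flags the first record of each month per id: it is 1 when the row's month differs from the previous row's month.
  • blah2 holds the lagged value on that flagged record, i.e. the value on the last date of the previous month; it is null on every other row.
  • last('blah2', ignorenulls=True) over the same date-ordered window then fills the nulls with the most recent non-null blah2.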
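The three steps can be mirrored in plain Python to see exactly where nulls appear and how the forward fill works. This is a sketch, not Spark code; ROWS and lag_flag_ffill are illustrative names. One quirk it makes visible: func.month returns only the month number, so two consecutive rows in January 2022 and January 2023 of the same id are not flagged as a month change.

```python
from datetime import date
from itertools import groupby

# Sample rows copied from the question: (date, id, value)
ROWS = [
    ("2022-07-29", 1, 1), ("2022-07-30", 1, 2), ("2022-07-31", 1, 3),
    ("2022-08-01", 1, 4), ("2022-08-02", 1, 5), ("2022-08-03", 1, 6),
    ("2022-09-10", 1, 8), ("2022-09-11", 1, 9), ("2022-09-12", 1, 10),
    ("2022-07-29", 2, 7), ("2022-07-30", 2, 6), ("2022-07-31", 2, 5),
    ("2022-08-01", 2, 4), ("2022-08-02", 2, 3), ("2022-08-03", 2, 2),
    ("2022-09-10", 2, 8), ("2022-09-11", 2, 9), ("2022-09-12", 2, 10),
]


def lag_flag_ffill(rows):
    """Hypothetical mirror of the blah / blah2 / last(ignorenulls=True) steps."""
    # Order by id, then date, like wd.partitionBy('id').orderBy('date')
    parsed = sorted((i, date.fromisoformat(s), v) for s, i, v in rows)
    out = []
    for i, group in groupby(parsed, key=lambda r: r[0]):
        prev_mth = prev_value = None  # lag() is null on the first row of each id
        filled = None  # running last('blah2', ignorenulls=True)
        for _, d, v in group:
            mth = d.month  # func.month('date'): month number only, the year is ignored
            # blah: 1 when the month changes, 0 otherwise, null on the first row
            blah = None if prev_mth is None else int(mth != prev_mth)
            # blah2: the previous row's value, kept only where blah == 1
            blah2 = prev_value if blah == 1 else None
            if blah2 is not None:
                filled = blah2
            out.append((i, d.isoformat(), v, filled))
            prev_mth, prev_value = mth, v
    return out
```

If that year collision matters, computing mth with func.date_trunc('month', 'date') instead of func.month('date') should avoid it, since the truncated value keeps the year.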
Answered By: samkart

samkart's answer above provides the main idea. Here is a solution that uses two windows instead of three.

days = lambda x: x * 86400
w1 = (
    Window
    .partitionBy("id")
    .orderBy(F.col("date").cast("timestamp").cast("long"))
    .rangeBetween(Window.unboundedPreceding, -days(1))
)
w2 = (
    Window
    .partitionBy("id", F.date_trunc("month", "date"))
    .orderBy(F.col("date"))
)

(
    df
    .withColumn("value_prev_day", F.last("value").over(w1))
    .withColumn("last_value_prev_month", F.first("value_prev_day").over(w2))
    .orderBy(["id", "date"])
    .show()
)

+----------+---+-----+--------------+---------------------+
|      date| id|value|value_prev_day|last_value_prev_month|
+----------+---+-----+--------------+---------------------+
|2022-07-29|  1|    1|          null|                 null|
|2022-07-30|  1|    2|             1|                 null|
|2022-07-31|  1|    3|             2|                 null|
|2022-08-01|  1|    4|             3|                    3|
|2022-08-02|  1|    5|             4|                    3|
|2022-08-03|  1|    6|             5|                    3|
|2022-09-10|  1|    8|             6|                    6|
|2022-09-11|  1|    9|             8|                    6|
|2022-09-12|  1|   10|             9|                    6|
|2022-07-29|  2|    7|          null|                 null|
|2022-07-30|  2|    6|             7|                 null|
|2022-07-31|  2|    5|             6|                 null|
|2022-08-01|  2|    4|             5|                    5|
|2022-08-02|  2|    3|             4|                    5|
|2022-08-03|  2|    2|             3|                    5|
|2022-09-10|  2|    8|             2|                    2|
|2022-09-11|  2|    9|             8|                    2|
|2022-09-12|  2|   10|             9|                    2|
+----------+---+-----+--------------+---------------------+
  • value_prev_day is the value on the most recent earlier date for the same id: the range frame of w1 ends one day (86400 seconds) before the current row. It is not necessarily yesterday's value; for 2022-09-10 it is the value from 2022-08-03.
  • Once we have this, we partition the data by id and by the month of the current row's date, and order each partition by date, so the earliest date of the month is the first row of the partition. We assign last_value_prev_month as first(value_prev_day) over this partition. This must be the last value before the month started (the last value of the previous month when that month has data), because it is the value_prev_day of the month's first row.
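The two windows can be traced in plain Python, including the epoch-second ordering key and the one-day (86400 s) range offset. This is a sketch under stated assumptions: ROWS, two_windows and epoch_seconds are illustrative names, and the date-to-timestamp cast is taken at midnight UTC. Spark casts a date to a timestamp in the session time zone, so under a zone with daylight saving time two consecutive dates can be 23 hours apart, and a fixed 86400-second offset may then skip the previous day.

```python
from datetime import date, datetime, timezone

# Sample rows copied from the question: (date, id, value)
ROWS = [
    ("2022-07-29", 1, 1), ("2022-07-30", 1, 2), ("2022-07-31", 1, 3),
    ("2022-08-01", 1, 4), ("2022-08-02", 1, 5), ("2022-08-03", 1, 6),
    ("2022-09-10", 1, 8), ("2022-09-11", 1, 9), ("2022-09-12", 1, 10),
    ("2022-07-29", 2, 7), ("2022-07-30", 2, 6), ("2022-07-31", 2, 5),
    ("2022-08-01", 2, 4), ("2022-08-02", 2, 3), ("2022-08-03", 2, 2),
    ("2022-09-10", 2, 8), ("2022-09-11", 2, 9), ("2022-09-12", 2, 10),
]

DAY = 86400  # days(1) in the answer: the range offset is in seconds


def epoch_seconds(d):
    # Mirrors F.col("date").cast("timestamp").cast("long"), assuming a UTC session time zone
    return int(datetime(d.year, d.month, d.day, tzinfo=timezone.utc).timestamp())


def two_windows(rows):
    """Hypothetical mirror of w1 / w2: (id, date, value, value_prev_day, last_value_prev_month)."""
    parsed = sorted((i, date.fromisoformat(s), v) for s, i, v in rows)
    out = []
    for i, d, v in parsed:
        # w1 frame: rows of the same id whose key is at least one day before the current key
        frame = [r for r in parsed if r[0] == i and epoch_seconds(r[1]) <= epoch_seconds(d) - DAY]
        # F.last over the date-ordered frame = value on the latest earlier date
        value_prev_day = frame[-1][2] if frame else None
        out.append((i, d, v, value_prev_day))
    # w2: first(value_prev_day) per (id, month); out is already sorted by (id, date)
    first_in_month = {}
    for i, d, _, vpd in out:
        first_in_month.setdefault((i, d.year, d.month), vpd)
    return [(i, d.isoformat(), v, vpd, first_in_month[(i, d.year, d.month)])
            for i, d, v, vpd in out]
```

The value_prev_day and last_value_prev_month columns it produces match the table in this answer.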
Answered By: user10166790