Explode dates and backfill rows in pyspark dataframe

Question:

I have this dataframe:

+---+----------+------+
| id|      date|amount|
+---+----------+------+
|123|2022-11-11|100.00|
|123|2022-11-12|100.00|
|123|2022-11-13|100.00|
|123|2022-11-14|200.00|
|456|2022-11-14|300.00|
|456|2022-11-15|300.00|
|456|2022-11-16|300.00|
|789|2022-11-11|400.00|
|789|2022-11-12|500.00|
+---+----------+------+

I need to create new records for each date up to current_date() - 2, and each new record must carry the most recent amount for that id.

For example, if date_sub(current_date(), 2) == "2022-11-16", then I need the following dataframe:

+---+----------+------+
| id|      date|amount|
+---+----------+------+
|123|2022-11-11|100.00|
|123|2022-11-12|100.00|
|123|2022-11-13|100.00|
|123|2022-11-14|200.00|
|123|2022-11-15|200.00|
|123|2022-11-16|200.00|
|456|2022-11-14|300.00|
|456|2022-11-15|300.00|
|456|2022-11-16|300.00|
|789|2022-11-11|400.00|
|789|2022-11-12|500.00|
|789|2022-11-13|500.00|
|789|2022-11-14|500.00|
|789|2022-11-15|500.00|
|789|2022-11-16|500.00|
+---+----------+------+

Code to reproduce the input dataframe:

import findspark
findspark.init()  # makes a local Spark installation importable (unnecessary if pyspark was pip-installed)

from datetime import date
from decimal import Decimal

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, DateType, DecimalType

spark = SparkSession.builder.master("local[4]").appName("Complete Rows").getOrCreate()

# sample rows: (id, date, amount)
vdata = [
    (123, date(2022, 11, 11), Decimal(100)),
    (123, date(2022, 11, 12), Decimal(100)),
    (123, date(2022, 11, 13), Decimal(100)),
    (123, date(2022, 11, 14), Decimal(200)),
    (456, date(2022, 11, 14), Decimal(300)),
    (456, date(2022, 11, 15), Decimal(300)),
    (456, date(2022, 11, 16), Decimal(300)),
    (789, date(2022, 11, 11), Decimal(400)),
    (789, date(2022, 11, 12), Decimal(500))]

schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("date", DateType(), False),
    StructField("amount", DecimalType(10, 2), False)])

df = spark.createDataFrame(vdata, schema)

df.show()

I tried to find the maximum date for each id, take the amount on that date, use sequence via F.expr to build the list of dates, and then explode that list into rows, but it's not working very well. Thanks for any help you can give!
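The approach described above can work for extending each id past its last date. As a minimal plain-Python sketch of that logic (no Spark involved; the helper name `extend_to` and the fixed `end` date, standing in for date_sub(current_date(), 2), are illustrative assumptions): take each id's maximum date and its amount, then generate one row per day after it up to `end`. Unlike the lookup approach in the answer below, this does not fill gaps between existing dates.

```python
from datetime import date, timedelta


def extend_to(rows, end):
    """Append one row per missing day after each id's last date, up to `end` inclusive.

    rows: list of (id, date, amount) tuples. Gaps between existing dates are NOT filled.
    `end` is a hypothetical fixed stand-in for date_sub(current_date(), 2).
    """
    # Maximum date per id together with the amount recorded on that date
    latest = {}
    for id_, d, amount in rows:
        if id_ not in latest or d > latest[id_][0]:
            latest[id_] = (d, amount)

    new_rows = []
    for id_, (last_date, amount) in latest.items():
        # Roughly what sequence(last_date + 1 day, end, interval 1 day) + explode would produce
        d = last_date + timedelta(days=1)
        while d <= end:
            new_rows.append((id_, d, amount))
            d += timedelta(days=1)

    return sorted(rows + new_rows)


rows = [
    (123, date(2022, 11, 13), 100.0),
    (123, date(2022, 11, 14), 200.0),
    (789, date(2022, 11, 11), 400.0),
    (789, date(2022, 11, 12), 500.0),
]
for row in extend_to(rows, end=date(2022, 11, 16)):
    print(row)
```

In Spark terms, `latest` plays the role of the per-id maximum date and its amount, and the `while` loop plays the role of `sequence` plus `explode`; the generated rows would then be unioned with the original dataframe.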

Asked By: TRCL


Answers:

I managed to find the following solution.
For clarity I divided it into three steps; the same logic can of course be written more compactly.

1) Lookup

Create a lookup table with every date each id needs, from its first date in the data up to current_date() - 2, whether or not that date is already present.

import pyspark.sql.functions as F
from pyspark.sql.window import Window

lookup = (df
          .groupby('id')
          .agg(
            F.min('date').alias('start_date'),                 # first date of each id
            F.date_sub(F.current_date(), 2).alias('end_date')  # last date to generate
          )
          # one row per day between start_date and end_date, both inclusive
          .select('id', F.explode(F.expr('sequence(start_date, end_date, interval 1 day)')).alias('date'))
         )
lookup.show()

+---+----------+
| id|      date|
+---+----------+
|123|2022-11-11|
|123|2022-11-12|
|123|2022-11-13|
|123|2022-11-14|
|123|2022-11-15|
|123|2022-11-16|
|456|2022-11-14|
|456|2022-11-15|
|456|2022-11-16|
|789|2022-11-11|
|789|2022-11-12|
|789|2022-11-13|
|789|2022-11-14|
|789|2022-11-15|
|789|2022-11-16|
+---+----------+
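To make the lookup step concrete, here is a rough plain-Python equivalent (no Spark; the helper `build_lookup` and the fixed `end` date are hypothetical stand-ins for the groupby/sequence/explode chain and for date_sub(current_date(), 2)). For an id whose first date is already after `end`, this sketch simply yields no rows; check how your Spark version's `sequence` behaves in that case.

```python
from datetime import date, timedelta


def build_lookup(rows, end):
    """Return (id, date) pairs from each id's first date through `end`, inclusive."""
    # Like groupby('id').agg(F.min('date')): earliest date per id
    start = {}
    for id_, d, _amount in rows:
        start[id_] = min(d, start.get(id_, d))

    lookup = []
    for id_, first in sorted(start.items()):
        # Like sequence(start_date, end_date, interval 1 day) followed by explode
        for offset in range((end - first).days + 1):
            lookup.append((id_, first + timedelta(days=offset)))
    return lookup


rows = [
    (456, date(2022, 11, 14), 300.0),
    (456, date(2022, 11, 15), 300.0),
    (789, date(2022, 11, 11), 400.0),
    (789, date(2022, 11, 12), 500.0),
]
print(build_lookup(rows, end=date(2022, 11, 16)))
```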

2) Join

Next, we outer-join the lookup table with the original dataframe on id and date: the missing dates are added as new rows with amount set to null.

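# outer join: keeps every original row and adds the missing (id, date) pairs with amount = null
df = df.join(lookup, on=['id', 'date'], how='outer')
df.orderBy('id', 'date').show()  # join output order is not guaranteed, so sort for display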

+---+----------+------+
| id|      date|amount|
+---+----------+------+
|123|2022-11-11| 100.0|
|123|2022-11-12| 100.0|
|123|2022-11-13| 100.0|
|123|2022-11-14| 200.0|
|123|2022-11-15|  null|
|123|2022-11-16|  null|
|456|2022-11-14| 300.0|
|456|2022-11-15| 300.0|
|456|2022-11-16| 300.0|
|789|2022-11-11| 400.0|
|789|2022-11-12| 500.0|
|789|2022-11-13|  null|
|789|2022-11-14|  null|
|789|2022-11-15|  null|
|789|2022-11-16|  null|
+---+----------+------+
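A rough plain-Python picture of what the outer join does, assuming rows are (id, date, amount) tuples and using a hypothetical `outer_join` helper: the result holds the union of keys from both sides, and keys missing from the data get `None`, which Spark displays as null. One consequence: original rows dated after the end date (none exist in this example) survive an outer join, whereas a left join starting from `lookup` would drop them.

```python
from datetime import date


def outer_join(rows, lookup):
    """Combine (id, date, amount) rows with (id, date) keys; keys absent from rows get None."""
    amounts = {(id_, d): amount for id_, d, amount in rows}
    keys = set(amounts) | set(lookup)  # union of both sides, as with how='outer'
    return [(id_, d, amounts.get((id_, d))) for id_, d in sorted(keys)]


rows = [(789, date(2022, 11, 11), 400.0), (789, date(2022, 11, 12), 500.0)]
lookup = [(789, date(2022, 11, day)) for day in range(11, 17)]
for row in outer_join(rows, lookup):
    print(row)
```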

3) last function

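We use the last function with ignorenulls=True over a window partitioned by id, ordered by date, and bounded from the first row of the partition to the current row. Each row therefore gets the most recent non-null amount up to and including its own date.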

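# running frame: from the first row of each id up to the current row
w = Window.partitionBy('id').orderBy('date').rowsBetween(Window.unboundedPreceding, Window.currentRow)

# most recent non-null amount within that frame
df = df.withColumn('amount', F.last('amount', ignorenulls=True).over(w))
df.orderBy('id', 'date').show()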

+---+----------+------+
| id|      date|amount|
+---+----------+------+
|123|2022-11-11| 100.0|
|123|2022-11-12| 100.0|
|123|2022-11-13| 100.0|
|123|2022-11-14| 200.0|
|123|2022-11-15| 200.0|
|123|2022-11-16| 200.0|
|456|2022-11-14| 300.0|
|456|2022-11-15| 300.0|
|456|2022-11-16| 300.0|
|789|2022-11-11| 400.0|
|789|2022-11-12| 500.0|
|789|2022-11-13| 500.0|
|789|2022-11-14| 500.0|
|789|2022-11-15| 500.0|
|789|2022-11-16| 500.0|
+---+----------+------+
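The fill step is a per-id forward fill. A minimal plain-Python sketch of the same semantics (no Spark; `forward_fill` is a hypothetical helper) shows three properties of the window above: the carried value resets for each id, a null with no earlier value in its id stays null, and only earlier rows are consulted because the frame ends at the current row.

```python
from datetime import date
from itertools import groupby


def forward_fill(rows):
    """Replace None amounts with the most recent non-None amount of the same id.

    Mirrors F.last('amount', ignorenulls=True) over a window partitioned by id,
    ordered by date, with frame (unboundedPreceding, currentRow).
    """
    filled = []
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))  # partition by id, order by date
    for _id, group in groupby(ordered, key=lambda r: r[0]):
        last_seen = None  # reset at the start of every partition
        for id_, d, amount in group:
            if amount is not None:
                last_seen = amount
            filled.append((id_, d, last_seen))
    return filled


rows = [
    (123, date(2022, 11, 14), 200.0),
    (123, date(2022, 11, 15), None),
    (456, date(2022, 11, 16), None),  # no earlier value for id 456, so it stays None
    (123, date(2022, 11, 16), None),
]
for row in forward_fill(rows):
    print(row)
```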
Answered By: Ric S