Get next value from a PySpark DataFrame column based on condition

Question:

I have a dataset such as:

date is_business_day
2023-01-01 0
2023-01-02 1
2023-01-03 1
2023-01-04 1
2023-01-05 1
2023-01-06 1
2023-01-07 0
2023-01-08 0
2023-01-09 1
2023-04-06 1
2023-04-07 0
2023-04-08 0
2023-04-09 0
2023-04-10 1

For each row, I would like to get the next date from the date column for which the condition is_business_day == 1 is met.

The desired output would be something like:

date is_business_day next_business_day
2023-01-01 0 2023-01-02
2023-01-02 1 2023-01-03
2023-01-03 1 2023-01-04
2023-01-04 1 2023-01-05
2023-01-05 1 2023-01-06
2023-01-06 1 2023-01-09
2023-01-07 0 2023-01-09
2023-01-08 0 2023-01-09
2023-01-09 1 2023-01-10
2023-01-10 1 2023-01-14
2023-01-11 0 2023-01-14
2023-01-12 0 2023-01-14
2023-01-13 0 2023-01-14
2023-01-14 1
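To pin down the semantics of the desired column, here is a minimal plain-Python sketch (not Spark code). It assumes the rows are already sorted by date and walks them from last to first, remembering the nearest business day seen so far. The helper name next_business_days and the sample rows (the date and is_business_day columns of the desired-output table) are illustrative assumptions.

```python
from datetime import date
from typing import List, Optional, Tuple


def next_business_days(rows: List[Tuple[date, int]]) -> List[Optional[date]]:
    """Hypothetical helper: for each (date, is_business_day) row, return the
    nearest strictly later date whose flag is 1, or None if there is none.

    Assumes `rows` is sorted by date in ascending order.
    """
    result: List[Optional[date]] = [None] * len(rows)
    upcoming: Optional[date] = None  # nearest business day among rows after i
    for i in range(len(rows) - 1, -1, -1):
        result[i] = upcoming  # only rows strictly after i have been visited
        day, is_business_day = rows[i]
        if is_business_day == 1:
            upcoming = day  # becomes the answer for all earlier rows
    return result


# Sample rows taken from the desired-output table in the question.
flags = [0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1]
rows = [(date(2023, 1, d), f) for d, f in zip(range(1, 15), flags)]
for (day, flag), nxt in zip(rows, next_business_days(rows)):
    print(day, flag, nxt)
```

The single backward pass is linear in the number of rows; the last business day gets None because nothing follows it.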

I have created the following function:

from datetime import datetime, timedelta
from pyspark.sql.functions import col

def next_business_day(df_calendar, date):
    date_f = datetime.strptime(date, '%Y-%m-%d')
    next_day = (date_f + timedelta(days=1)).strftime('%Y-%m-%d')

    # Filter the DataFrame to keep only the dates AFTER the date being checked,
    # sorted so that the first business day found is the nearest one.
    df_calendar_next_days = df_calendar.filter(col('date') >= next_day).orderBy('date')

    # Collect the remaining rows to the driver as a list so we can iterate over them.
    df_it = df_calendar_next_days.collect()

    # Return the first date where is_business_day == 1,
    # or None if no later business day exists.
    for row in df_it:
        if row['is_business_day'] == 1:
            return row['date']
    return None

The function works, but I can't use it inside .withColumn() because I can't pass the DataFrame as a parameter.

If I try code like this:

df_calendar = (
    df_calendar
        .withColumn('next_business_day', next_business_day(df_calendar, col('date')))
)

I receive the error:

TypeError: Invalid argument, not a string or column: DataFrame[date: date] of type <class 'pyspark.sql.dataframe.DataFrame'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.
Asked By: Marcos Martins


Answers:

I didn't debug what is wrong with the current code, but what you want to achieve can be done with PySpark's built-in functions.

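Using F.min on a conditional expression, it looks for the minimum date where is_business_day == 1, searching only from the next row (current row + 1) onward. F.when without .otherwise() returns null for non-business days, and F.min ignores nulls, so the result is the earliest business day after the current row, or null when there is none.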

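from pyspark.sql import functions as F
from pyspark.sql import Window

# Frame: from the next row (currentRow + 1) to the last row, ordered by date.
# Without partitionBy, Spark moves all rows into a single partition for this window.
w = Window.orderBy('date').rowsBetween(Window.currentRow + 1, Window.unboundedFollowing)

df = df.withColumn('next_business_day', F.min(F.when(F.col('is_business_day') == 1, F.col('date'))).over(w))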
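For intuition about what the window expression above computes row by row, here is a hypothetical plain-Python emulation of it (not Spark code). For each row it takes the frame from the next row to the end, maps business days to their date and everything else to None (like F.when without otherwise), then takes the minimum while skipping None (like F.min skipping nulls). The function name conditional_min_following is an assumption for illustration; it is O(n²) and assumes rows sorted by date.

```python
from datetime import date
from typing import List, Optional, Tuple


def conditional_min_following(rows: List[Tuple[date, int]]) -> List[Optional[date]]:
    """Emulate F.min(F.when(flag == 1, date)).over(rowsBetween(1, unboundedFollowing))."""
    out: List[Optional[date]] = []
    for i in range(len(rows)):
        frame = rows[i + 1:]  # window frame: next row through the last row
        # F.when without otherwise(): the date for business days, None (null) otherwise
        candidates = [d if flag == 1 else None for d, flag in frame]
        # Like SQL MIN, skip nulls; an empty or all-null frame yields null
        non_null = [d for d in candidates if d is not None]
        out.append(min(non_null) if non_null else None)
    return out


flags = [0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1]
rows = [(date(2023, 1, d), f) for d, f in zip(range(1, 15), flags)]
for (day, flag), nxt in zip(rows, conditional_min_following(rows)):
    print(day, flag, nxt)
```

On the 14 rows of the desired-output table this reproduces the next_business_day column, with None for the last row. Because the rows are sorted ascending, the minimum over the following business days is always the nearest one.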
Answered By: Emma