Python-polars: rolling_sum where the window_size comes from another column

Question:

Consider the following dataframe:

from datetime import datetime
import polars as pl

df = pl.DataFrame(
    {
        "date": pl.date_range(
            low=datetime(2023, 2, 1),
            high=datetime(2023, 2, 5),
            interval="1d"),
        "periods": [2, 2, 2, 1, 1],
        "quantity": [10, 12, 14, 16, 18],
        "calculate": [22, 26, 30, 16, 18]
    }
)

The column calculate is what I want: a rolling_sum of quantity where the window_size parameter is taken from the periods column rather than being a fixed value. For example, in the first row periods is 2, so calculate = 10 + 12 = 22 (the current row plus the next row); in the fourth row periods is 1, so calculate = 16.

I can do the following (window_size=2):

df.select(pl.col("quantity").rolling_sum(window_size=2))

However, I get an error when I try to do this:

df.select(pl.col("quantity").rolling_sum(window_size=pl.col("periods")))

This is the error:

TypeError: argument 'window_size': 'Expr' object cannot be converted to 'PyString'

How do I pass the value of window_size based on another column? I also looked at using groupby_rolling but could not figure that out either.

Asked By: anerjee

Answers:

It seems like this should be easier to do which suggests I may be missing something obvious.

As a workaround, you could use the row count to generate row indexes for the windows.

df = (  
   df
   .with_row_count()
   .with_columns(
      window = 
         pl.arange(
            pl.col("row_nr"), 
            pl.col("row_nr") + pl.col("periods")))
)      
shape: (5, 6)
┌────────┬─────────────────────┬─────────┬──────────┬───────────┬───────────┐
│ row_nr ┆ date                ┆ periods ┆ quantity ┆ calculate ┆ window    │
│ ---    ┆ ---                 ┆ ---     ┆ ---      ┆ ---       ┆ ---       │
│ u32    ┆ datetime[μs]        ┆ i64     ┆ i64      ┆ i64       ┆ list[i64] │
╞════════╪═════════════════════╪═════════╪══════════╪═══════════╪═══════════╡
│ 0      ┆ 2023-02-01 00:00:00 ┆ 2       ┆ 10       ┆ 22        ┆ [0, 1]    │
│ 1      ┆ 2023-02-02 00:00:00 ┆ 2       ┆ 12       ┆ 26        ┆ [1, 2]    │
│ 2      ┆ 2023-02-03 00:00:00 ┆ 2       ┆ 14       ┆ 30        ┆ [2, 3]    │
│ 3      ┆ 2023-02-04 00:00:00 ┆ 1       ┆ 16       ┆ 16        ┆ [3]       │
│ 4      ┆ 2023-02-05 00:00:00 ┆ 1       ┆ 18       ┆ 18        ┆ [4]       │
└────────┴─────────────────────┴─────────┴──────────┴───────────┴───────────┘

.take() can't unpack a list, but you can .flatten() it, although you then lose the original shape.

df.select(rolling = pl.col("quantity").take(pl.col("window").flatten()))
shape: (8, 1)
┌─────────┐
│ rolling │
│ ---     │
│ i64     │
╞═════════╡
│ 10      │
│ 12      │
│ 12      │
│ 14      │
│ 14      │
│ 16      │
│ 16      │
│ 18      │
└─────────┘

You could .explode() the dataframe to make it the same length, then .groupby() it back together.

(
   df
   .explode("window")
   .with_columns(
      df.select(
         rolling = 
            pl.col("quantity")
              .take(pl.col("window").flatten())))
   .groupby("row_nr", maintain_order=True)
   .agg([
      pl.exclude("rolling").first(), 
      pl.col("rolling").sum()
   ])
)
shape: (5, 7)
┌────────┬─────────────────────┬─────────┬──────────┬───────────┬────────┬─────────┐
│ row_nr ┆ date                ┆ periods ┆ quantity ┆ calculate ┆ window ┆ rolling │
│ ---    ┆ ---                 ┆ ---     ┆ ---      ┆ ---       ┆ ---    ┆ ---     │
│ u32    ┆ datetime[μs]        ┆ i64     ┆ i64      ┆ i64       ┆ i64    ┆ i64     │
╞════════╪═════════════════════╪═════════╪══════════╪═══════════╪════════╪═════════╡
│ 0      ┆ 2023-02-01 00:00:00 ┆ 2       ┆ 10       ┆ 22        ┆ 0      ┆ 22      │
│ 1      ┆ 2023-02-02 00:00:00 ┆ 2       ┆ 12       ┆ 26        ┆ 1      ┆ 26      │
│ 2      ┆ 2023-02-03 00:00:00 ┆ 2       ┆ 14       ┆ 30        ┆ 2      ┆ 30      │
│ 3      ┆ 2023-02-04 00:00:00 ┆ 1       ┆ 16       ┆ 16        ┆ 3      ┆ 16      │
│ 4      ┆ 2023-02-05 00:00:00 ┆ 1       ┆ 18       ┆ 18        ┆ 4      ┆ 18      │
└────────┴─────────────────────┴─────────┴──────────┴───────────┴────────┴─────────┘
Answered By: jqurious

Very similar to @jqurious's, but (I think) a bit simplified:

(
    df.lazy()
    .with_row_count('i')
    .with_columns(
        window=pl.arange(
            pl.col("i"),
            pl.col("i") + pl.col("periods")),
        qty=pl.col('quantity').list()
    )
    .with_columns(
        rollsum=pl.col('qty').arr.take(pl.col('window')).arr.sum()
    )
    .select(pl.exclude(['window', 'qty', 'i']))
    .collect()
)

It works on the same concept: it essentially recreates the whole quantity column as a list on every row, then uses the window column to pick the corresponding values out of that list and sums them.
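
In newer polars releases several of these names have changed. Here is a minimal sketch of the same idea, assuming a recent version where the relevant methods are with_row_index, pl.int_ranges, .implode(), .list.gather and .list.sum (verify the names against the version you have installed):

(
    df.lazy()
    .with_row_index("i")
    .with_columns(
        # per-row window of row positions: [i, i + periods)
        window=pl.int_ranges("i", pl.col("i") + pl.col("periods")),
        # the whole quantity column as a single list, broadcast to every row
        qty=pl.col("quantity").implode(),
    )
    .with_columns(
        # pick the windowed values out of that list and sum them
        rollsum=pl.col("qty").list.gather(pl.col("window")).list.sum()
    )
    .select(pl.exclude(["window", "qty", "i"]))
    .collect()
)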

Another method is to just use a loop, which will be more memory efficient.

First, you want to get all the unique values of periods, then initialize a rollsum column in the df, reverse the row order, and replace the column with a rolling_sum calculation for every period. The reversal is needed because rolling_sum windows end at the current row (they look backwards), whereas the desired calculation starts at the current row and looks forwards. At the end, put the rows back in the original order.

periods = df.get_column('periods').unique()
# placeholder column; reverse so the backward-looking rolling_sum looks forward
df = df.with_columns(pl.lit(None).cast(pl.Float64()).alias("rollsum")).sort('date', reverse=True)
for period in periods:
    df = df.with_columns(
        pl.when(pl.col('periods') == period)
        .then(pl.col('quantity').rolling_sum(window_size=period))
        .otherwise(pl.col('rollsum'))
        .alias('rollsum')
    )
df = df.sort('date')  # put the rows back in the original order
df
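
As a quick sanity check (a sketch against the example frame from the question), the computed rollsum should match the expected calculate column:

# rollsum was initialised as Float64, so cast before comparing with calculate
assert (
    df.get_column('rollsum').cast(pl.Int64).to_list()
    == df.get_column('calculate').to_list()
)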
Answered By: Dean MacGregor