Element-wise aggregation of a column of type List[f64] in PyPolars

Question:

I want to apply aggregation functions such as sum and mean to a column of type List[f64] after a groupby, so that I get a List[f64] entry back for each group.
Say I have:

import polars as pl

df = pl.DataFrame(
    {
        "Case": ["case1", "case1"],
        "List": [[1, 2, 3], [4, 5, 6]],
    }
)

print(df)
shape: (2, 2)
┌───────┬────────────┐
│ Case  ┆ List       │
│ ---   ┆ ---        │
│ str   ┆ list[i64]  │
╞═══════╪════════════╡
│ case1 ┆ [1, 2, 3]  │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
│ case1 ┆ [4, 5, 6]  │
└───────┴────────────┘

I want to group by Case and sum List element-wise so that I end up with:

┌───────┬────────────┐
│ Case  ┆ List       │
│ ---   ┆ ---        │
│ str   ┆ list[i64]  │
╞═══════╪════════════╡
│ case1 ┆ [5, 7, 9]  │
└───────┴────────────┘

How would I best do this? Note that the length of each of the lists is 256, so indexing each of them is not a good solution.

Thanks!

Asked By: dashdeckers

Answers:

"Note that the length of each of the lists is 256, so indexing each of them is not a good solution."

If you know the length of your lists ahead of time, and every list has that same length, you can avoid the typical explode/index/groupby solution as follows:

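list_size = 3  # length shared by all lists (256 in the question's data)
(
    df.groupby("Case")
    .agg(
        # Build one aggregated expression per list position, then
        # stitch the per-position sums back together into a single list.
        # (Newer Polars releases rename groupby to group_by and .arr to .list.)
        pl.concat_list(
            [
                pl.col("List")
                .arr.slice(n, 1)  # keep only the element at position n
                .arr.first()      # unwrap the one-element list
                .sum()            # sum that position across the group
                for n in range(0, list_size)
            ]
        )
    )
)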
shape: (1, 2)
┌───────┬───────────┐
│ Case  ┆ List      │
│ ---   ┆ ---       │
│ str   ┆ list[i64] │
╞═══════╪═══════════╡
│ case1 ┆ [5, 7, 9] │
└───────┴───────────┘
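
As a sanity check on the expected result, the same element-wise sum can be sketched in plain Python, without Polars. This is only a reference to compare against, not how Polars computes the result, and the function name sum_lists_by_group is made up for illustration. It assumes all lists within a group have the same length.

```python
from collections import defaultdict


def sum_lists_by_group(cases, lists):
    """Element-wise sum of equal-length lists, grouped by case label."""
    grouped = defaultdict(list)
    for case, values in zip(cases, lists):
        grouped[case].append(values)
    # zip(*rows) walks the grouped lists position by position.
    return {
        case: [sum(column) for column in zip(*rows)]
        for case, rows in grouped.items()
    }


result = sum_lists_by_group(["case1", "case1"], [[1, 2, 3], [4, 5, 6]])
print(result)  # {'case1': [5, 7, 9]}
```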

How it works

To see how this works, let’s look at how the algorithm adds the first elements of each list. (We can extrapolate for all elements from this example.)

In the first step, we use groupby to accumulate all the lists for each Case.

(
    df.groupby("Case")
    .agg(
        pl.concat_list(
            [
                pl.col("List")
            ]
        )
    )
)
shape: (1, 2)
┌───────┬────────────────────────┐
│ Case  ┆ List                   │
│ ---   ┆ ---                    │
│ str   ┆ list[list[i64]]        │
╞═══════╪════════════════════════╡
│ case1 ┆ [[1, 2, 3], [4, 5, 6]] │
└───────┴────────────────────────┘

The next step is to slice each list so that we get only the nth element of each list. In this example, we want only the first element of each list, corresponding to the 0 in slice(0, 1). Notice that each inner list now holds exactly one element.

(
    df.groupby("Case")
    .agg(
        pl.concat_list(
            [
                pl.col("List")
                .arr.slice(0, 1)
            ]
        )
    )
)
shape: (1, 2)
┌───────┬─────────────────┐
│ Case  ┆ List            │
│ ---   ┆ ---             │
│ str   ┆ list[list[i64]] │
╞═══════╪═════════════════╡
│ case1 ┆ [[1], [4]]      │
└───────┴─────────────────┘

In the last step, we unwrap each one-element list with arr.first() and sum the resulting values across the group:

(
    df.groupby("Case")
    .agg(
        pl.concat_list(
            [
                pl.col("List")
                .arr.slice(0, 1)
                .arr.first()
                .sum()
            ]
        )
    )
)
shape: (1, 2)
┌───────┬───────────┐
│ Case  ┆ List      │
│ ---   ┆ ---       │
│ str   ┆ list[i64] │
╞═══════╪═══════════╡
│ case1 ┆ [5]       │
└───────┴───────────┘

To do this for every position in the lists, we generate one such expression per position with a list comprehension, substituting n into slice(n, 1), and let concat_list assemble the per-position sums into a single list.
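
The three steps can also be mirrored in plain Python for a single group, which may help when reasoning about what each expression contributes. This is a rough analogy rather than the Polars implementation, and sum_position is a hypothetical helper name.

```python
group = [[1, 2, 3], [4, 5, 6]]  # the lists collected for "case1"
list_size = 3


def sum_position(rows, n):
    sliced = [row[n:n + 1] for row in rows]  # like .arr.slice(n, 1): [[1], [4]] for n = 0
    firsts = [s[0] for s in sliced]          # like .arr.first(): [1, 4] for n = 0
    return sum(firsts)                       # like .sum(): 5 for n = 0


# One sum per position, collected into a single list (the role of concat_list).
totals = [sum_position(group, n) for n in range(list_size)]
print(totals)  # [5, 7, 9]
```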

Answered By: cbilot