Sample from each group in polars dataframe?

Question:

I’m looking for a function along the lines of

df.groupby('column').agg(sample(10))

so that I can take ten or so randomly selected elements from each group.

This is specifically so I can read in a LazyFrame and work with a small sample of each group as opposed to the entire dataframe.

Update:

One approximate solution is:

df = (
    lf.group_by('column')
    .agg(pl.all().sample(fraction=0.001))
    .collect()
)
df = df.explode(df.columns[1:])

Update 2:

That approximate solution is just the same as sampling the whole dataframe and doing a groupby after. No good.

Asked By: user6268172


Answers:

We can try making our own groupby-like functionality and sampling from the filtered subsets.

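import polars as pl

samples = []
cats = df.get_column('column').unique().to_list()
for cat in cats:
    # keep only the rows of one group, then sample 10 of them
    # (sample raises an error if a group has fewer than 10 rows)
    samples.append(df.filter(pl.col('column') == cat).sample(10))
samples = pl.concat(samples)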

I found partition_by in the documentation. It should be more efficient, since the groups are created by the API in a single pass over the DataFrame. Unfortunately, the sampling itself still loops over the groups one at a time in Python.

pl.concat([x.sample(10) for x in df.partition_by("column")])

Third attempt, sampling indices:

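import numpy as np
import polars as pl
import random

# one list of row indices per group (agg_groups returns each group's row positions)
indices = df.group_by("group").agg(pl.col("value").agg_groups()).get_column("value").to_list()
# pick 10 row indices from each group (random.sample raises if a group has fewer than 10)
sampled = np.array([random.sample(x, 10) for x in indices]).flatten()
df[sampled]
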
Answered By: BeRT2me

Let's start with some dummy data:

import polars as pl

n = 100
seed = 0
df = pl.DataFrame(
    {
        "groups": (pl.int_range(0, n, eager=True) % 5).shuffle(seed=seed),
        "values": pl.int_range(0, n, eager=True).shuffle(seed=seed)
    }
)
df
shape: (100, 2)
┌────────┬────────┐
│ groups ┆ values │
│ ---    ┆ ---    │
│ i64    ┆ i64    │
╞════════╪════════╡
│ 0      ┆ 55     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 0      ┆ 40     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2      ┆ 57     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4      ┆ 99     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ ...    ┆ ...    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2      ┆ 87     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 1      ┆ 96     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 3      ┆ 43     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4      ┆ 44     │
└────────┴────────┘

This gives us 5 groups of 100 / 5 = 20 elements each. Let’s verify that:

df.group_by("groups").agg(pl.len().alias("count"))
shape: (5, 2)
┌────────┬───────┐
│ groups ┆ count │
│ ---    ┆ ---   │
│ i64    ┆ u32   │
╞════════╪═══════╡
│ 1      ┆ 20    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3      ┆ 20    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 4      ┆ 20    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2      ┆ 20    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 0      ┆ 20    │
└────────┴───────┘

Sample our data

Now we are going to use a window function to take a sample of our data.

df.filter(
    pl.int_range(0, pl.len()).shuffle().over("groups") < 10
)
shape: (50, 2)
┌────────┬────────┐
│ groups ┆ values │
│ ---    ┆ ---    │
│ i64    ┆ i64    │
╞════════╪════════╡
│ 0      ┆ 85     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 0      ┆ 0      │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4      ┆ 84     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4      ┆ 19     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ ...    ┆ ...    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2      ┆ 87     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 1      ┆ 96     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 3      ┆ 43     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4      ┆ 44     │
└────────┴────────┘

For every group defined by over("groups"), the int_range expression creates a row index that runs from 0 to the group's length minus 1. We then shuffle that range, so that we take a random sample rather than a slice of the first rows. Finally, we keep only the rows whose shuffled index is lower than 10. This creates a boolean mask that we can pass to the filter method. A group with fewer than 10 rows is returned in full.

Answered By: ritchie46

This worked better for me:

sampled_df = pl.concat(
    part.sample(fraction=0.001)
    for part in df.partition_by(["column"], include_key=True)
)

The problem with .agg(pl.col("column").sample(2)) was that it sampled each column independently, so the values in an output row did not necessarily come from the same input row. What I needed was randomly selected whole rows.

Answered By: santon