Polars: Replace parts of dataframe with other parts of dataframe

Question:

I’m looking for an efficient way to copy / replace parts of a dataframe with other parts of the same dataframe in Polars.

For instance, in the following minimal example dataframe

pl.DataFrame({
  "year": [2020,2021,2020,2021],
  "district_id": [1,2,1,2],
  "distribution_id": [1, 1, 2, 2],
  "var_1": [1,2,0.1,0.3],
  "var_N": [1,2,0.3,0.5],
  "unrelated_var": [0.2,0.5,0.3,0.7],
})

I’d like to replace all column values of "var_1" & "var_N" where the "distribution_id" = 2 with the corresponding values where the "distribution_id" = 1.

This is the desired result:

pl.DataFrame({
  "year": [2020,2021,2020,2021],
  "district_id": [1,2,1,2],
  "distribution_id": [1, 1, 2, 2],
  "var_1": [1,2,1,2],
  "var_N": [1,2,1,2],
  "unrelated_var": [0.2,0.5,0.3,0.7],
})

I tried to use a "when" expression, but it fails with "polars.exceptions.ShapeError: shapes of self, mask and other are not suitable for zip_with operation"

df = df.with_columns([
    pl.when(pl.col("distribution_id") == 2)
    .then(df.filter(pl.col("distribution_id") == 1)[col])
    .otherwise(pl.col(col))
    .alias(col)
    for col in columns_to_copy
])

Here’s what I used to do with SQLAlchemy:

table_alias = table.alias("table_alias")
stmt = (
    table.update()
    .where(table.c.year == table_alias.c.year)
    .where(table.c.d_id == table_alias.c.d_id)
    .where(table_alias.c.distribution_id == 1)
    .where(table.c.distribution_id == 2)
    .values(var_1=table_alias.c.var_1,
            var_n=table_alias.c.var_n)
)

Thanks a lot for your help!

Asked By: Christoph Pahmeyer


Answers:

In this concrete example, this can be solved with a simple pl.Expr.forward_fill after setting the var_* columns to None, which can be done with a pl.when().then() expression that omits the .otherwise() term (the missing branch defaults to null).

import polars.selectors as cs

(
    df
    .with_columns(
        pl.when(pl.col("distribution_id") == 1).then(cs.starts_with("var_"))
    )
    .sort("year", "district_id")
    .with_columns(
        cs.starts_with("var_").forward_fill()
    )
)
shape: (4, 6)
┌──────┬─────────────┬─────────────────┬───────┬───────┬───────────────┐
│ year ┆ district_id ┆ distribution_id ┆ var_1 ┆ var_N ┆ unrelated_var │
│ ---  ┆ ---         ┆ ---             ┆ ---   ┆ ---   ┆ ---           │
│ i64  ┆ i64         ┆ i64             ┆ f64   ┆ f64   ┆ f64           │
╞══════╪═════════════╪═════════════════╪═══════╪═══════╪═══════════════╡
│ 2020 ┆ 1           ┆ 1               ┆ 1.0   ┆ 1.0   ┆ 0.2           │
│ 2020 ┆ 1           ┆ 2               ┆ 1.0   ┆ 1.0   ┆ 0.3           │
│ 2021 ┆ 2           ┆ 1               ┆ 2.0   ┆ 2.0   ┆ 0.5           │
│ 2021 ┆ 2           ┆ 2               ┆ 2.0   ┆ 2.0   ┆ 0.7           │
└──────┴─────────────┴─────────────────┴───────┴───────┴───────────────┘
Answered By: Hericks

As a more generally applicable approach (e.g. if the distribution ids to be filled / to fill with are not consecutive), we can

  1. left-join with a filtered dataframe of the values to fill with,
  2. use a pl.when().then().otherwise() construct to select either the original or the joined column value, depending on the value of distribution_id,
  3. drop the joined columns.

import polars.selectors as cs

(
    df
    .join(
        df.filter(pl.col("distribution_id") == 1),
        on=["year", "district_id"],
        how="left",
    )
    .with_columns(
        pl.when(pl.col("distribution_id") != 2).then(col).otherwise(col + "_right")
        for col in ["var_1", "var_N"]
    )
    .drop(cs.ends_with("_right"))
)
shape: (4, 6)
┌──────┬─────────────┬─────────────────┬───────┬───────┬───────────────┐
│ year ┆ district_id ┆ distribution_id ┆ var_1 ┆ var_N ┆ unrelated_var │
│ ---  ┆ ---         ┆ ---             ┆ ---   ┆ ---   ┆ ---           │
│ i64  ┆ i64         ┆ i64             ┆ f64   ┆ f64   ┆ f64           │
╞══════╪═════════════╪═════════════════╪═══════╪═══════╪═══════════════╡
│ 2020 ┆ 1           ┆ 1               ┆ 1.0   ┆ 1.0   ┆ 0.2           │
│ 2021 ┆ 2           ┆ 1               ┆ 2.0   ┆ 2.0   ┆ 0.5           │
│ 2020 ┆ 1           ┆ 2               ┆ 1.0   ┆ 1.0   ┆ 0.3           │
│ 2021 ┆ 2           ┆ 2               ┆ 2.0   ┆ 2.0   ┆ 0.7           │
└──────┴─────────────┴─────────────────┴───────┴───────┴───────────────┘
Answered By: Hericks

You could filter the distribution_id = 1 rows, change their id to 2 and discard the unneeded columns.

df.filter(distribution_id = 1).select(
   "year", "district_id", "^var_.+$", distribution_id = pl.lit(2, pl.Int64)
)
shape: (2, 5)
┌──────┬─────────────┬───────┬───────┬─────────────────┐
│ year ┆ district_id ┆ var_1 ┆ var_N ┆ distribution_id │
│ ---  ┆ ---         ┆ ---   ┆ ---   ┆ ---             │
│ i64  ┆ i64         ┆ f64   ┆ f64   ┆ i64             │
╞══════╪═════════════╪═══════╪═══════╪═════════════════╡
│ 2020 ┆ 1           ┆ 1.0   ┆ 1.0   ┆ 2               │
│ 2021 ┆ 2           ┆ 2.0   ┆ 2.0   ┆ 2               │
└──────┴─────────────┴───────┴───────┴─────────────────┘
  • (note: "^var_.+$" selects columns by regex, but selectors can be used if preferred.)

With the data "aligned", you can pass it to .update()

df.update(
   df.filter(distribution_id = 1)
     .select("year", "district_id", "^var_.+$", distribution_id = pl.lit(2, pl.Int64)),
   on=["year", "district_id", "distribution_id"]
)
shape: (4, 6)
┌──────┬─────────────┬─────────────────┬───────┬───────┬───────────────┐
│ year ┆ district_id ┆ distribution_id ┆ var_1 ┆ var_N ┆ unrelated_var │
│ ---  ┆ ---         ┆ ---             ┆ ---   ┆ ---   ┆ ---           │
│ i64  ┆ i64         ┆ i64             ┆ f64   ┆ f64   ┆ f64           │
╞══════╪═════════════╪═════════════════╪═══════╪═══════╪═══════════════╡
│ 2020 ┆ 1           ┆ 1               ┆ 1.0   ┆ 1.0   ┆ 0.2           │
│ 2021 ┆ 2           ┆ 1               ┆ 2.0   ┆ 2.0   ┆ 0.5           │
│ 2020 ┆ 1           ┆ 2               ┆ 1.0   ┆ 1.0   ┆ 0.3           │
│ 2021 ┆ 2           ┆ 2               ┆ 2.0   ┆ 2.0   ┆ 0.7           │
└──────┴─────────────┴─────────────────┴───────┴───────┴───────────────┘
Answered By: jqurious

Here’s an approach that more closely resembles what you attempted to do:

df.with_columns(
    pl.when(pl.col('distribution_id')==2)
    .then(
        pl.col('^var.*$')
        .filter(pl.col('distribution_id')==1)
        .first()
        .over('year','district_id')
        )
    .otherwise(pl.col('^var.*$'))
)

It starts off with a when because you only want to replace the values when distribution_id == 2. Then it takes the columns you want to change (multiple columns via regex) and filters them down to the rows you want to source from. If you stopped at the filter, you would be trying to feed a column that is 4 rows long with a column that is 2 rows long, which, of course, doesn’t work. To make it work, you tell it that you just want the first value within each group you care about. The group keys go into over, much like a df.group_by(...).agg(...), except that over goes last in the chain. Since year/district_id presumably forms a unique primary key, there is only one source value per group, so you could just as well use first/last/max/min/etc.; it just needs to be some aggregate function.

Answered By: Dean MacGregor