Why does xarray.Dataset.where expand coordinates?

Question:

I don’t understand this behavior in xarray; I probably don’t understand the broadcasting it does. I’ve made a contrived example illustrating the issue.

import numpy as np
import xarray as xr

# These are coordinates
years = xr.DataArray(np.arange(2018, 2021), dims="year")
ids = xr.DataArray(np.arange(1, 4), dims="id")

# These are data with different coordinates
year_data = xr.DataArray(np.arange(18, 21), dims="year", coords={"year": years})
id_data = xr.DataArray(['a', 'b', 'c'], dims="id", coords={"id": ids})
comb_data = xr.DataArray(np.arange(9).reshape(3, 3), dims=["year", "id"], coords={"year": years, "id": ids})

# Make a dataset
ds = xr.Dataset(data_vars={"comb_data": comb_data, "id_data": id_data, "year_data": year_data})

This makes:

<xarray.Dataset>
Dimensions:    (year: 3, id: 3)
Coordinates:
  * year       (year) int64 2018 2019 2020
  * id         (id) int64 1 2 3
Data variables:
    comb_data  (year, id) int64 0 1 2 3 4 5 6 7 8
    id_data    (id) <U1 'a' 'b' 'c'
    year_data  (year) int64 18 19 20

This is what I want: two data variables that each use a different dimension, and one data variable that uses both. I need to set some data to 0, so I use where:

ds.where(ds.coords["id"] == 2, 0)

<xarray.Dataset>
Dimensions:    (year: 3, id: 3)
Coordinates:
  * year       (year) int64 2018 2019 2020
  * id         (id) int64 1 2 3
Data variables:
    comb_data  (year, id) int64 0 1 0 0 4 0 0 7 0
    id_data    (id) object 0 'b' 0
    year_data  (year, id) int64 0 18 0 0 19 0 0 20 0

Now year_data has gained the id dimension and contains values with no meaning. I want where to leave alone the variables that don’t have the dimensions used in the condition. I could delete the extraneous data after the fact, but that doesn’t feel right. Is there a better way to do this?

Asked By: rlank


Answers:

xr.Dataset.where always broadcasts the variables in the dataset against the supplied condition. Because you told it to mask all the data in the dataset except where ds["id"] == 2, every data variable in the dataset is broadcast against the id dimension, including year_data, which did not have it before.
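A minimal sketch of that broadcasting rule, rebuilding the question's dataset (with plain NumPy coordinates for brevity). It uses xr.broadcast to show that combining a (year,) array with an (id,) condition yields an array over both dimensions, which is the same shape Dataset.where gives year_data. The names cond and year_b are illustrative only.

```python
import numpy as np
import xarray as xr

# Rebuild the dataset from the question
ds = xr.Dataset(
    data_vars={
        "comb_data": (("year", "id"), np.arange(9).reshape(3, 3)),
        "id_data": ("id", ["a", "b", "c"]),
        "year_data": ("year", np.arange(18, 21)),
    },
    coords={"year": np.arange(2018, 2021), "id": np.arange(1, 4)},
)

# The condition only has the "id" dimension
cond = ds["id"] == 2
print(cond.dims)  # ('id',)

# Broadcasting year_data against the condition produces an array over the
# union of both dimensions; this is how year_data picks up an "id" axis
year_b, cond_b = xr.broadcast(ds["year_data"], cond)
print(year_b.dims, year_b.shape)

# Dataset.where gives year_data the same two dimensions
print(ds.where(cond, 0)["year_data"].dims)
```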

If you don’t want this to happen, you have two options. The first is good practice in general: unless you really want to carry out an operation across all variables in a dataset, operate only on the specific variables you want to work with:

# this will only modify the comb_data array
ds["comb_data"] = ds["comb_data"].where(ds["id"] == 2)

See the docs on broadcasting and automatic alignment for more info about xarray’s computing rules.
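If several variables need the same mask, the first option can be generalized with a loop. This is a minimal sketch, assuming the mask should apply only to data variables that contain every dimension of the condition, and that 0 is the desired fill value as in the question; the names masked and cond are hypothetical. Variables without the id dimension, such as year_data, are left untouched.

```python
import numpy as np
import xarray as xr

# Rebuild the dataset from the question
ds = xr.Dataset(
    data_vars={
        "comb_data": (("year", "id"), np.arange(9).reshape(3, 3)),
        "id_data": ("id", ["a", "b", "c"]),
        "year_data": ("year", np.arange(18, 21)),
    },
    coords={"year": np.arange(2018, 2021), "id": np.arange(1, 4)},
)

cond = ds["id"] == 2

# Mask only the variables whose dimensions include all of the condition's
# dimensions; everything else is copied over unchanged
masked = ds.copy()
for name, var in ds.data_vars.items():
    if set(cond.dims).issubset(var.dims):
        masked[name] = var.where(cond, 0)

print(masked["year_data"].dims)  # ('year',)
print(masked["comb_data"].values)
```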

Another option, if "id_data" and "year_data" are more metadata than true data variables, is to set them as non-dimension coordinates. Dataset computations such as where operate on the data variables, so they will skip over these:

In [4]: ds = ds.set_coords(["id_data", "year_data"])

In [5]: ds
Out[5]:
<xarray.Dataset>
Dimensions:    (year: 3, id: 3)
Coordinates:
  * year       (year) int64 2018 2019 2020
  * id         (id) int64 1 2 3
    id_data    (id) <U1 'a' 'b' 'c'
    year_data  (year) int64 18 19 20
Data variables:
    comb_data  (year, id) int64 0 1 2 3 4 5 6 7 8

In [6]: ds.where(ds["id"] == 2)
Out[6]:
<xarray.Dataset>
Dimensions:    (year: 3, id: 3)
Coordinates:
  * year       (year) int64 2018 2019 2020
  * id         (id) int64 1 2 3
    id_data    (id) <U1 'a' 'b' 'c'
    year_data  (year) int64 18 19 20
Data variables:
    comb_data  (year, id) float64 nan 1.0 nan nan 4.0 nan nan 7.0 nan

See the docs on coordinates for more information about how coordinates are handled in computation.
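A minimal sketch of the coordinate route with the question's fill value of 0, followed by Dataset.reset_coords to turn the two coordinates back into data variables afterwards. Whether you want that round trip is an assumption about your workflow; the names ds_c, out and restored are illustrative.

```python
import numpy as np
import xarray as xr

# Rebuild the dataset from the question
ds = xr.Dataset(
    data_vars={
        "comb_data": (("year", "id"), np.arange(9).reshape(3, 3)),
        "id_data": ("id", ["a", "b", "c"]),
        "year_data": ("year", np.arange(18, 21)),
    },
    coords={"year": np.arange(2018, 2021), "id": np.arange(1, 4)},
)

# Demote the 1-D variables to non-dimension coordinates
ds_c = ds.set_coords(["id_data", "year_data"])

# where now only touches comb_data; the coordinates keep their own dims
out = ds_c.where(ds_c["id"] == 2, 0)
print(out["year_data"].dims)  # ('year',)

# Optionally promote them back to data variables
restored = out.reset_coords(["id_data", "year_data"])
print(sorted(restored.data_vars))
```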

Answered By: Michael Delgado