Why does xarray.Dataset.where expand coordinates?
Question:
I’m just not understanding this behavior in xarray, and I probably just don’t understand the broadcasting xarray does. I’ve made a contrived example illustrating the issue.
import numpy as np
import xarray as xr
# These are coordinates
years = xr.DataArray(np.arange(2018, 2021), dims="year")
ids = xr.DataArray(np.arange(1, 4), dims="id")
# These are data with different coordinates
year_data = xr.DataArray(np.arange(18, 21), dims="year", coords={"year": years})
id_data = xr.DataArray(['a', 'b', 'c'], dims="id", coords={"id": ids})
comb_data = xr.DataArray(np.arange(9).reshape(3, 3), dims=["year", "id"], coords={"year": years, "id": ids})
# Make a dataset
ds = xr.Dataset(data_vars={"comb_data": comb_data, "id_data": id_data, "year_data": year_data})
This makes:
<xarray.Dataset>
Dimensions: (year: 3, id: 3)
Coordinates:
* year (year) int64 2018 2019 2020
* id (id) int64 1 2 3
Data variables:
comb_data (year, id) int64 0 1 2 3 4 5 6 7 8
id_data (id) <U1 'a' 'b' 'c'
year_data (year) int64 18 19 20
This is what I want: 2 data variables that each use a different coordinate and 1 data variable that uses both. I need to set some data to 0, so I use where:
ds.where(ds.coords["id"] == 2, 0)
<xarray.Dataset>
Dimensions: (year: 3, id: 3)
Coordinates:
* year (year) int64 2018 2019 2020
* id (id) int64 1 2 3
Data variables:
comb_data (year, id) int64 0 1 0 0 4 0 0 7 0
id_data (id) object 0 'b' 0
year_data (year, id) int64 0 18 0 0 19 0 0 20 0
Now year_data has gained the id dimension and contains data with no meaning. I just need to ignore the dimensions that aren’t involved in this. I can delete the extraneous data after the fact, but that doesn’t feel right. Is there a better way to do this?
Answers:
xr.Dataset.where always broadcasts every data variable in the dataset against the supplied condition. Since you told it to mask all the data in the dataset except where ds["id"] == 2, and that condition is an array along the id dimension, every data variable is broadcast against id. That is why year_data, which only had the year dimension, gains an id dimension.
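To see this broadcasting directly, the sketch below rebuilds the question's dataset in a more compact (but equivalent) form and compares each variable's dimensions before and after where. The condition ds["id"] == 2 only has the id dimension, so year_data, which lacks id, is the variable that gains it.

```python
import numpy as np
import xarray as xr

# The question's dataset, built in one call (same variables and coordinates)
ds = xr.Dataset(
    data_vars={
        "comb_data": (("year", "id"), np.arange(9).reshape(3, 3)),
        "id_data": ("id", np.array(["a", "b", "c"])),
        "year_data": ("year", np.arange(18, 21)),
    },
    coords={"year": np.arange(2018, 2021), "id": np.arange(1, 4)},
)

cond = ds["id"] == 2  # boolean DataArray along "id" only
masked = ds.where(cond, 0)

# Every data variable is broadcast against cond, so each one ends up with "id"
for name in ds.data_vars:
    print(name, ds[name].dims, "->", masked[name].dims)
```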
If you don’t want this to happen, you have two options. The first is good practice in general: unless you really want to carry out an operation across all variables in a dataset, operate only on the specific variables you want to work with:
# this will only modify the comb_data array
ds["comb_data"] = ds["comb_data"].where(ds["id"] == 2)
See the docs on broadcasting and automatic alignment for more info about xarray’s computing rules.
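If you would rather keep a single call over the whole dataset, one possible approach (a sketch, not a built-in xarray option) is to use Dataset.map to apply where only to the data variables that already contain every dimension of the condition, and pass the others through unchanged. The helper name where_matching_dims is hypothetical; the fill value 0 is the one from the question.

```python
import numpy as np
import xarray as xr


def where_matching_dims(ds, cond, other):
    """Hypothetical helper: apply DataArray.where only to data variables
    that contain every dimension of cond; return the rest unchanged."""
    def _mask(da):
        if set(cond.dims) <= set(da.dims):
            return da.where(cond, other)
        return da  # no broadcasting for variables that lack cond's dims
    return ds.map(_mask)


# The question's dataset, built in one call (same variables and coordinates)
ds = xr.Dataset(
    data_vars={
        "comb_data": (("year", "id"), np.arange(9).reshape(3, 3)),
        "id_data": ("id", np.array(["a", "b", "c"])),
        "year_data": ("year", np.arange(18, 21)),
    },
    coords={"year": np.arange(2018, 2021), "id": np.arange(1, 4)},
)

result = where_matching_dims(ds, ds["id"] == 2, 0)
print(result)
```

Note that id_data does contain id, so it is still masked (and, as in the question's output, mixing strings with 0 gives an object array); year_data keeps only its year dimension.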
Another option, if "id_data" and "year_data" are more metadata than true data variables, is to set them as non-dimension coordinates. Dataset.where, like most computations on a dataset, then acts only on the data variables and carries these coordinates along unchanged:
In [4]: ds = ds.set_coords(["id_data", "year_data"])
In [5]: ds
Out[5]:
<xarray.Dataset>
Dimensions: (year: 3, id: 3)
Coordinates:
* year (year) int64 2018 2019 2020
* id (id) int64 1 2 3
id_data (id) <U1 'a' 'b' 'c'
year_data (year) int64 18 19 20
Data variables:
comb_data (year, id) int64 0 1 2 3 4 5 6 7 8
In [6]: ds.where(ds["id"] == 2)
Out[6]:
<xarray.Dataset>
Dimensions: (year: 3, id: 3)
Coordinates:
* year (year) int64 2018 2019 2020
* id (id) int64 1 2 3
id_data (id) <U1 'a' 'b' 'c'
year_data (year) int64 18 19 20
Data variables:
comb_data (year, id) float64 nan 1.0 nan nan 4.0 nan nan 7.0 nan
See the docs on coordinates for more information about how coordinates are handled in computation.
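If id_data and year_data need to be data variables again after masking, the coordinate approach can be done as a round trip: set_coords before where and reset_coords after it. This is a sketch that assumes the question's dataset (rebuilt compactly here) and its fill value of 0; the name meta is just for illustration.

```python
import numpy as np
import xarray as xr

# The question's dataset, built in one call (same variables and coordinates)
ds = xr.Dataset(
    data_vars={
        "comb_data": (("year", "id"), np.arange(9).reshape(3, 3)),
        "id_data": ("id", np.array(["a", "b", "c"])),
        "year_data": ("year", np.arange(18, 21)),
    },
    coords={"year": np.arange(2018, 2021), "id": np.arange(1, 4)},
)

meta = ["id_data", "year_data"]  # variables to protect from broadcasting
result = (
    ds.set_coords(meta)          # move them out of data_vars
    .where(ds["id"] == 2, 0)     # only comb_data is masked
    .reset_coords(meta)          # turn them back into data variables
)
print(result)
```

Because 0 is passed as the fill value, comb_data stays int64 instead of being promoted to float64 with NaN as in Out[6].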