I am processing a large dataset of non-rectangular satellite imagery with xarray/dask. I know in advance from the footprint of the imagery that large parts are no-data, because the image is not rectangular while an xarray array has to be. As a result, many chunks of my array are all NaN. I know this, but xarray doesn't. Is there a way to exclude these chunks from computation?
Here is a pseudo-code example to illustrate:
<code>import xarray as xr
# open_zarr returns a Dataset, so select the band variable to get a DataArray
da = xr.open_zarr('src/path', decode_coords='all')['band_data']
print(da)
# <xarray.DataArray 'band_data' (y: 28981, x: 41721)>
# dask.array<getitem, shape=(28981, 41721), dtype=float32, chunksize=(2048, 2048),
# chunktype=numpy.ndarray>
# Coordinates:
# spatial_ref int64 ...
# * x (x) float64 7.18e+05 7.18e+05 7.18e+05 ... 8.432e+05 8.432e+05
# * y (y) float64 9.736e+06 9.736e+06 ... 9.823e+06 9.823e+06
# this loads and computes every chunk, including the ones I know in advance are all NaN
# ideally, the all-NaN chunks would be neither loaded nor computed
result = computationally_expensive_func(da)  # placeholder for the actual processing
...
result.compute()
</code>
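One possible approach, shown only as a minimal sketch under stated assumptions: if the footprint can be turned into a boolean grid `has_data` with one entry per dask chunk (shape equal to `arr.numblocks`), `dask.array.map_blocks` can look up each chunk's position through `block_info` and return a NaN block instead of calling the expensive function. The names `skip_empty_chunks`, `has_data` and `expensive` are hypothetical, and `expensive` stands in for `computationally_expensive_func`. Note that this skips the *computation* only. The input chunk is still passed to the function, so it is still read from storage.

```python
import numpy as np
import dask.array as dsa


def expensive(block):
    # Hypothetical stand-in for computationally_expensive_func
    return np.sqrt(np.abs(block)) * 2.0


def skip_empty_chunks(arr, has_data, func):
    """Apply func only to chunks flagged True in has_data.

    has_data: boolean NumPy array of shape arr.numblocks, assumed to be
    derived in advance from the known image footprint.
    """
    def _apply(block, block_info=None):
        # Position of this chunk in the chunk grid, e.g. (0, 1)
        idx = tuple(block_info[0]["chunk-location"])
        if not has_data[idx]:
            # Known to be all NaN: skip func and return a NaN block
            return np.full(block.shape, np.nan, dtype=block.dtype)
        return func(block)

    # Passing meta avoids dask calling _apply on a dummy block to infer it
    return arr.map_blocks(_apply, dtype=arr.dtype,
                          meta=np.empty((0,) * arr.ndim, dtype=arr.dtype))
```

With an xarray `DataArray`, the wrapped dask array could be put back into the original coordinates with something like `result = da.copy(data=skip_empty_chunks(da.data, has_data, computationally_expensive_func))`. This assumes the function operates chunk by chunk and preserves the chunk shape.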