Functional transformations with CoordinateTransformIndex
Highlights
- Coordinate variables whose values are described by coordinate transformations (i.e., by a set of formulas relating array indices to coordinate labels) are best handled via an xarray.indexes.CoordinateTransformIndex.
- Coordinate variables associated with such an index are lazy and therefore use very little memory. They may have arbitrary dimensions.
- Alignment may be implemented in an optimal way, i.e., by comparing coordinate transformation parameters rather than raw coordinate labels.
- Xarray exposes an abstract CoordinateTransform class to plug in 3rd-party coordinate transformations, with support for dimension and coordinate variable names (see the example below).
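To make the forward/reverse idea concrete, here is a minimal, library-free sketch (NumPy only) of a one-dimensional affine transform. The class `Affine1DTransform` and its `offset`/`scale` parameters are hypothetical: they only mirror the shape of the `forward`/`reverse` methods used later with `xr.indexes.CoordinateTransform`, and this is not Xarray's implementation. It illustrates three of the points above: labels are computed only for the requested positions, labels map back to positions, and two transforms can be compared through a few parameters instead of through all their labels.

```python
import numpy as np


class Affine1DTransform:
    """Hypothetical 1-D transform: label = offset + scale * index."""

    def __init__(self, dim: str, coord: str, size: int, offset: float, scale: float):
        self.dim = dim
        self.coord = coord
        self.size = size
        self.offset = offset
        self.scale = scale

    def forward(self, dim_positions: dict) -> dict:
        # Array positions -> coordinate labels, computed only for the given positions.
        pos = np.asarray(dim_positions[self.dim], dtype=float)
        return {self.coord: self.offset + self.scale * pos}

    def reverse(self, coord_labels: dict) -> dict:
        # Coordinate labels -> (possibly fractional) array positions.
        labels = np.asarray(coord_labels[self.coord], dtype=float)
        return {self.dim: (labels - self.offset) / self.scale}

    def equals(self, other: "Affine1DTransform") -> bool:
        # Compare five parameters instead of `size` materialized labels.
        return (self.dim, self.coord, self.size, self.offset, self.scale) == (
            other.dim, other.coord, other.size, other.offset, other.scale
        )


t = Affine1DTransform("x", "lon", size=1_000_000, offset=10.0, scale=0.5)
# Only three labels are computed; no million-element array is built.
print(t.forward({"x": [0, 1, 999_999]}))
print(t.reverse({"lon": [10.0, 10.5]}))
print(t.equals(Affine1DTransform("x", "lon", 1_000_000, 10.0, 0.5)))
```

The exact equality check on floats is deliberate here: two transforms built from the same parameters describe identical labels, so no tolerance is needed for this comparison.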
See also
CoordinateTransformIndex is often used as a building block by other
custom indexes such as xarray.indexes.RangeIndex (see
Floating point ranges with RangeIndex) and rasterix.RasterIndex (see
Raster affine transforms with rasterix.RasterIndex).
Example (Astronomy)
As a real-world example, let’s create a custom
xarray.indexes.CoordinateTransform that wraps an
astropy.wcs.WCS object. This Xarray coordinate transform adapter
class simply maps Xarray dimension and coordinate variable names to pixel
and world axis names of the shared Python interface for World Coordinate
Systems used in Astropy.
Note
This example is taken and adapted from this gist by Stuart Mumford.
It only provides basic integration between Astropy’s WCS and Xarray
coordinate transforms. More advanced integration could leverage the
slicing capabilities of WCS objects
in a custom xarray.indexes.CoordinateTransformIndex subclass.
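The idea behind slicing a transform instead of its labels can be shown without Astropy. Assuming a simple affine transform (label = offset + scale * i), slicing the array with `start:stop:step` only needs new parameters: the new offset is the label at `start`, the new scale is `scale * step`, and no labels are materialized. The helper `slice_affine` below is hypothetical and covers only this affine case; Astropy's WCS objects provide their own slicing, which is what a real subclass would rely on.

```python
import numpy as np


def slice_affine(offset: float, scale: float, size: int, s: slice):
    """Hypothetical helper: new (offset, scale, size) of label = offset + scale * i
    after applying slice `s` to an axis of length `size`."""
    start, stop, step = s.indices(size)
    new_size = len(range(start, stop, step))
    # The label of the first kept element becomes the new offset and the
    # spacing between kept elements is scale * step (also valid for step < 0).
    return offset + scale * start, scale * step, new_size


offset, scale, size = 100.0, 0.25, 40
new_offset, new_scale, new_size = slice_affine(offset, scale, size, slice(4, 30, 3))

# Check against the labels obtained by materializing and slicing.
full = offset + scale * np.arange(size)
sliced = new_offset + new_scale * np.arange(new_size)
assert np.allclose(full[4:30:3], sliced)
```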
```python
from collections.abc import Hashable
from typing import Any

import numpy as np
import xarray as xr
from astropy.wcs import WCS


def escape(name):
    # Turn WCS physical types such as "pos.eq.ra" into names like "pos_eq_ra",
    # dropping any "custom:" prefix.
    return name.replace(".", "_").replace("custom:", "")


class WCSCoordinateTransform(xr.indexes.CoordinateTransform):
    """Lightweight adapter class for the World Coordinate Systems (WCS) API.

    More info: https://docs.astropy.org/en/latest/wcs/wcsapi.html
    """

    def __init__(self, wcs: WCS):
        # Dimension names: WCS pixel axis names, or "dim0", "dim1", ... if unnamed.
        pixel_axis_names = [
            pan or f"dim{i}"
            for i, pan in enumerate(wcs.pixel_axis_names)
        ]
        # Coordinate names: world axis names, else physical types, else "coord{i}".
        world_axis_names = [
            escape(wan or wphy or f"coord{i}")
            for i, (wan, wphy) in enumerate(
                zip(
                    wcs.world_axis_names,
                    wcs.world_axis_physical_types,
                )
            )
        ]
        dim_size = {
            name: size
            for name, size in zip(pixel_axis_names, wcs.array_shape)
        }

        super().__init__(
            world_axis_names, dim_size, dtype=np.dtype(float)
        )
        self.wcs = wcs

    def forward(
        self, dim_positions: dict[str, Any]
    ) -> dict[Hashable, Any]:
        """Perform array -> world coordinate transformation."""
        pixel = [dim_positions[dim] for dim in self.dims]
        world = self.wcs.array_index_to_world_values(*pixel)
        return {name: w for name, w in zip(self.coord_names, world)}

    def reverse(
        self, coord_labels: dict[Hashable, Any]
    ) -> dict[str, Any]:
        """Perform world -> array coordinate reverse transformation."""
        world = [coord_labels[name] for name in self.coord_names]
        # Returns integer array indices (rounded to the nearest pixel).
        pixel = self.wcs.world_to_array_index_values(*world)
        return {name: p for name, p in zip(self.dims, pixel)}
```
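The `escape` helper and the name fallbacks in `__init__` determine the coordinate names that appear in the outputs below (for example `pos_eq_ra`). The stdlib-only sketch below reproduces that logic outside of Astropy so it can be checked in isolation. The function `world_coord_names` is a hypothetical stand-in for the list comprehension in `WCSCoordinateTransform.__init__`, and the physical-type strings passed to it are illustrative inputs.

```python
def escape(name):
    # Same helper as in the adapter class above.
    return name.replace(".", "_").replace("custom:", "")


def world_coord_names(world_axis_names, physical_types):
    """Hypothetical stand-in for the name logic in WCSCoordinateTransform.__init__."""
    return [
        escape(wan or wphy or f"coord{i}")
        for i, (wan, wphy) in enumerate(zip(world_axis_names, physical_types))
    ]


# Empty axis names fall back to the escaped physical type...
print(world_coord_names(["", ""], ["pos.eq.ra", "pos.eq.dec"]))
# ['pos_eq_ra', 'pos_eq_dec']

# ...and to a positional name when the physical type is missing too.
print(world_coord_names(["", ""], ["custom:pos.helioprojective.lon", None]))
# ['pos_helioprojective_lon', 'coord1']
```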
Assigning
Let’s now create a small function that opens a FITS file with Astropy, creates
an Xarray CoordinateTransformIndex and its
associated lazy coordinate variables from the WCS
object, and returns both the data and the coordinates as an
xarray.DataArray.
```python
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename


def open_fits_dataarray(filename, item=0):
    hdu = fits.open(filename)[item]
    wcs = WCS(hdu.header)

    transform = WCSCoordinateTransform(wcs)
    index = xr.indexes.CoordinateTransformIndex(transform)
    coords = xr.Coordinates.from_xindex(index)

    return xr.DataArray(
        hdu.data,
        coords=coords,
        dims=transform.dims,
        attrs={"wcs": wcs},
    )
```
Open a simple image with two celestial axes.
```python
fname = get_pkg_data_filename("galactic_center/gc_2mass_k.fits")
da_2d = open_fits_dataarray(fname)
da_2d
```

```
<xarray.DataArray (dim0: 720, dim1: 721)> Size: 2MB
array([[563.158, 540.406, ..., 501.953, 640.751],
       [514.725, 586.917, ..., 504.929, 506.394],
       ...,
       [521.363, 524.339, ..., 546.312, 544.252],
       [519.669, 525.712, ..., 522.279, 543.062]],
      shape=(720, 721), dtype=float32)
Coordinates:
  * pos_eq_ra   (dim0, dim1) float64 4MB 267.0 267.0 267.0 ... 265.8 265.8 265.8
  * pos_eq_dec  (dim0, dim1) float64 4MB -29.43 -29.43 -29.43 ... -28.43 -28.43
Dimensions without coordinates: dim0, dim1
Indexes:
  ┌ pos_eq_ra   CoordinateTransformIndex
  └ pos_eq_dec
Attributes:
    wcs:      WCS Keywords\n\nNumber of WCS axes: 2\nCTYPE : 'RA---TAN' 'DEC-...
```

```python
# lazy coordinate variables!
da_2d.pos_eq_ra
```

```
<xarray.DataArray 'pos_eq_ra' (dim0: 720, dim1: 721)> Size: 4MB
[519120 values with dtype=float64]
Coordinates:
  * pos_eq_ra   (dim0, dim1) float64 4MB 267.0 267.0 267.0 ... 265.8 265.8 265.8
  * pos_eq_dec  (dim0, dim1) float64 4MB -29.43 -29.43 -29.43 ... -28.43 -28.43
Dimensions without coordinates: dim0, dim1
Indexes:
  ┌ pos_eq_ra   CoordinateTransformIndex
  └ pos_eq_dec
```
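The `[519120 values with dtype=float64]` line indicates that the values have not been computed yet. When they are needed, a transform-backed coordinate can be materialized by evaluating `forward` on a grid of integer array positions, roughly as in the NumPy sketch below. The linear `forward` function is a hypothetical stand-in (not the TAN projection evaluated by the real WCS), and the sketch is not Xarray's internal code; it mainly shows where the shape and the 4MB figure in the repr come from (720 × 721 float64 values).

```python
import numpy as np

# Dimension sizes taken from the repr above.
dim_size = {"dim0": 720, "dim1": 721}


def forward(dim_positions):
    """Hypothetical linear stand-in for WCS.array_index_to_world_values."""
    i, j = dim_positions["dim0"], dim_positions["dim1"]
    return {"pos_eq_ra": 267.0 - 0.001 * j, "pos_eq_dec": -29.43 + 0.001 * i}


# One integer position array per dimension; "ij" indexing keeps array order.
positions = np.meshgrid(*(np.arange(n) for n in dim_size.values()), indexing="ij")
coords = forward(dict(zip(dim_size, positions)))

ra = coords["pos_eq_ra"]
print(ra.shape, ra.size, ra.nbytes)  # (720, 721) 519120 4152960
```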
```python
da_2d.plot.pcolormesh(
    x="pos_eq_ra",
    y="pos_eq_dec",
    vmax=1300,
    cmap="magma",
);
```
Open a spectral cube with two celestial axes and one spectral axis.
```python
fname = get_pkg_data_filename("l1448/l1448_13co.fits")
da_3d = open_fits_dataarray(fname)
da_3d
```

```
<xarray.DataArray (dim0: 53, dim1: 105, dim2: 105)> Size: 2MB
array([[[ 1.565e-01, 1.617e-01, ..., 6.135e-01, 8.006e-01],
        [-6.556e-02, 1.834e-01, ..., 4.520e-01, 5.302e-01],
        ...,
        [-1.023e-01, -2.522e-01, ..., -1.493e-01, -7.321e-02],
        [ 2.357e-01, -1.268e-01, ..., 2.561e-01, -7.383e-02]],

       [[ 1.803e-01, 3.841e-02, ..., 3.219e-01, 3.325e-01],
        [ 5.100e-02, -1.466e-01, ..., 2.972e-01, 4.661e-01],
        ...,
        [-4.534e-04, -1.212e-01, ..., -1.109e-01, -9.869e-02],
        [ 2.601e-01, -7.199e-02, ..., 5.910e-02, 3.332e-01]],

       ...,

       [[ 7.934e-02, 1.493e-01, ..., 2.147e-01, 2.845e-01],
        [ 3.147e-01, 4.547e-01, ..., 1.561e-01, 3.217e-01],
        ...,
        [ 9.095e-01, 8.347e-01, ..., 2.485e-01, -3.963e-02],
        [ 6.973e-01, 4.107e-01, ..., -5.433e-02, -3.194e-01]],

       [[ 1.101e-01, 1.940e-01, ..., -1.001e-01, 1.648e-01],
        [ 9.819e-02, 1.949e-01, ..., 2.342e-01, -2.520e-01],
        ...,
        [ 3.003e-01, 3.390e-01, ..., 1.896e-01, 2.429e-01],
        [ 4.334e-01, 5.407e-01, ..., -1.176e-01, -4.120e-02]]],
      shape=(53, 105, 105), dtype='>f4')
Coordinates:
  * pos_eq_ra               (dim0, dim1, dim2) float64 5MB 51.74 51.73 ... 50.92
  * pos_eq_dec              (dim0, dim1, dim2) float64 5MB 30.3 30.3 ... 30.97
  * spect_dopplerVeloc_opt  (dim0, dim1, dim2) float64 5MB 2.528e+03 ... 5.98...
Dimensions without coordinates: dim0, dim1, dim2
Indexes:
  ┌ pos_eq_ra               CoordinateTransformIndex
  │ pos_eq_dec
  └ spect_dopplerVeloc_opt
Attributes:
    wcs:      WCS Keywords\n\nNumber of WCS axes: 3\nCTYPE : 'RA---SFL' 'DEC-...
```
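The sizes in this repr follow directly from the shape. Each float64 coordinate would need 53 × 105 × 105 × 8 bytes once materialized (about 4.7 MB, displayed as 5MB), while the `>f4` data (big-endian 4-byte floats, the byte order used by FITS) needs half of that per value (displayed as 2MB). The short NumPy check below does the arithmetic; the three-coordinate total is roughly what the lazy index avoids allocating up front.

```python
import numpy as np

shape = (53, 105, 105)
n_values = int(np.prod(shape))

coord_bytes = n_values * np.dtype(np.float64).itemsize  # one coordinate variable
data_bytes = n_values * np.dtype(">f4").itemsize  # the FITS data cube

print(n_values)         # 584325
print(coord_bytes)      # 4674600
print(data_bytes)       # 2337300
print(3 * coord_bytes)  # 14023800 (all three coordinates materialized)
```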