Functional transformations with CoordinateTransformIndex


Highlights

  1. The coordinate variables whose values are described by coordinate transformations (i.e., by a set of formulas describing the relationship between array indices and coordinate labels) are best handled via an xarray.indexes.CoordinateTransformIndex.

  2. Coordinate variables associated with such an index are lazy and therefore use very little memory. They may have arbitrary dimensions.

  3. Alignment may be implemented in an optimal way, i.e., based on coordinate transformation parameters rather than on raw coordinate labels.

  4. Xarray exposes an abstract CoordinateTransform class to plug in third-party coordinate transformations with support for dimension and coordinate variable names (see the example below).

See also

CoordinateTransformIndex is often used as a building block by other custom indexes such as xarray.indexes.RangeIndex (see Floating point ranges with RangeIndex) and rasterix.RasterIndex (see Raster affine transforms with rasterix.RasterIndex).

Example (Astronomy)

As a real-world example, let’s create a custom xarray.indexes.CoordinateTransform that wraps an astropy.wcs.WCS object. This Xarray coordinate transform adapter class simply maps Xarray dimension and coordinate variable names to pixel and world axis names of the shared Python interface for World Coordinate Systems used in Astropy.

Note

This example is adapted from a gist by Stuart Mumford.

It only provides basic integration between Astropy’s WCS and Xarray coordinate transforms. More advanced integration could leverage the slicing capabilities of WCS objects in a custom xarray.indexes.CoordinateTransformIndex subclass.

from collections.abc import Hashable
from typing import Any

import numpy as np
import xarray as xr
from astropy.wcs import WCS


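def escape(name):
    # Turn WCS names such as "pos.eq.ra" or "custom:..." into
    # identifier-friendly coordinate names such as "pos_eq_ra".
    return name.replace(".", "_").replace("custom:", "")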


class WCSCoordinateTransform(xr.indexes.CoordinateTransform):
    """Lightweight adapter class for the World Coordinate Systems (WCS) API.

    More info: https://docs.astropy.org/en/latest/wcs/wcsapi.html
    """

    def __init__(self, wcs: WCS):
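        # WCS pixel axes are ordered (x, y, ...), the reverse of the
        # NumPy array axes used by ``wcs.array_shape`` and by Xarray dims.
        pixel_axis_names = [
            pan or f"dim{i}"
            for i, pan in enumerate(reversed(wcs.pixel_axis_names))
        ]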
        world_axis_names = [
            escape(wan or wphy or f"coord{i}")
            for i, (wan, wphy) in enumerate(
                zip(
                    wcs.world_axis_names,
                    wcs.world_axis_physical_types,
                )
            )
        ]
        dim_size = {
            name: size
            for name, size in zip(pixel_axis_names, wcs.array_shape)
        }

        super().__init__(
            world_axis_names, dim_size, dtype=np.dtype(float)
        )
        self.wcs = wcs

    def forward(
        self, dim_positions: dict[str, Any]
    ) -> dict[Hashable, Any]:
        """Perform array -> world coordinate transformation."""

        pixel = [dim_positions[dim] for dim in self.dims]
        world = self.wcs.array_index_to_world_values(*pixel)

        return {name: w for name, w in zip(self.coord_names, world)}

    def reverse(
        self, coord_labels: dict[Hashable, Any]
    ) -> dict[str, Any]:
        """Perform world -> array coordinate reverse transformation."""

        world = [coord_labels[name] for name in self.coord_names]
        # Astropy returns the nearest integer array indices here.
        pixel = self.wcs.world_to_array_index_values(*world)

        return {name: p for name, p in zip(self.dims, pixel)}
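The coordinate names seen in the outputs below (pos_eq_ra, pos_eq_dec, spect_dopplerVeloc_opt) come from the fallback chain in __init__: the world axis name, then the physical type, then a generic coordN name, all passed through escape. The standalone sketch below reproduces only that naming logic. world_axis_names_for is a hypothetical helper, and its inputs are assumptions modeled on those outputs (empty world axis names, dotted physical types, one missing type and one "custom:" type).

```python
def escape(name):
    # Same helper as in the adapter above.
    return name.replace(".", "_").replace("custom:", "")


def world_axis_names_for(names, physical_types):
    # Hypothetical helper mirroring WCSCoordinateTransform.__init__:
    # prefer the world axis name, then the physical type, then "coord{i}".
    return [
        escape(wan or wphy or f"coord{i}")
        for i, (wan, wphy) in enumerate(zip(names, physical_types))
    ]


print(world_axis_names_for(
    ["", "", ""], ["pos.eq.ra", "pos.eq.dec", "spect.dopplerVeloc.opt"]
))
# ['pos_eq_ra', 'pos_eq_dec', 'spect_dopplerVeloc_opt']

print(world_axis_names_for(["", ""], [None, "custom:pos.eq.dec"]))
# ['coord0', 'pos_eq_dec']
```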

Assigning

Let’s now create a small function that opens a FITS file with Astropy, creates an Xarray CoordinateTransformIndex and its associated lazy coordinate variables from the WCS object, and returns both the data and the coordinates as an xarray.DataArray.

from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename


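def open_fits_dataarray(filename, item=0):
    # Read the selected HDU and build its WCS from the header.
    hdu = fits.open(filename)[item]
    wcs = WCS(hdu.header)

    # Wrap the WCS in the adapter, then in an Xarray index,
    # and derive the (lazy) coordinate variables from that index.
    transform = WCSCoordinateTransform(wcs)
    index = xr.indexes.CoordinateTransformIndex(transform)
    coords = xr.Coordinates.from_xindex(index)

    return xr.DataArray(
        hdu.data,
        coords=coords,
        dims=transform.dims,
        attrs={"wcs": wcs},
    )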
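The coordinate variables returned by Coordinates.from_xindex hold no label values; labels are produced by forward only when they are requested. Conceptually, fully loading them means building an integer index grid over all dimensions and passing it through forward, as in the numpy-only sketch below. materialize and toy_forward are hypothetical names, and this is a simplified illustration rather than Xarray's actual implementation.

```python
import numpy as np


def materialize(forward, dim_size):
    """Evaluate a forward transform on every position of the array grid.

    `forward` maps {dim: index array} -> {coord name: label array};
    `dim_size` maps dim names to sizes, in array order.
    """
    dims = list(dim_size)
    # One integer index array per dimension, each with the full grid shape.
    positions = np.meshgrid(
        *[np.arange(dim_size[d]) for d in dims], indexing="ij"
    )
    return forward(dict(zip(dims, positions)))


def toy_forward(dim_positions):
    # Hypothetical 2-D transform: each label is a scaled, offset index.
    return {
        "lon": 100.0 + 0.5 * dim_positions["dim1"],
        "lat": -10.0 + 0.25 * dim_positions["dim0"],
    }


coords = materialize(toy_forward, {"dim0": 3, "dim1": 4})
print(coords["lon"].shape)  # (3, 4)
```

Every coordinate ends up with the full array shape, which is why storing the transform instead of the labels saves memory for multi-dimensional coordinates.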

Open a simple image with two celestial axes.

fname = get_pkg_data_filename("galactic_center/gc_2mass_k.fits")

da_2d = open_fits_dataarray(fname)
da_2d
<xarray.DataArray (dim0: 720, dim1: 721)> Size: 2MB
array([[563.158, 540.406, ..., 501.953, 640.751],
       [514.725, 586.917, ..., 504.929, 506.394],
       ...,
       [521.363, 524.339, ..., 546.312, 544.252],
       [519.669, 525.712, ..., 522.279, 543.062]],
      shape=(720, 721), dtype=float32)
Coordinates:
  * pos_eq_ra   (dim0, dim1) float64 4MB 267.0 267.0 267.0 ... 265.8 265.8 265.8
  * pos_eq_dec  (dim0, dim1) float64 4MB -29.43 -29.43 -29.43 ... -28.43 -28.43
Dimensions without coordinates: dim0, dim1
Indexes:
  ┌ pos_eq_ra   CoordinateTransformIndex
  └ pos_eq_dec
Attributes:
    wcs:      WCS Keywords\n\nNumber of WCS axes: 2\nCTYPE : 'RA---TAN' 'DEC-...
# lazy coordinate variables!

da_2d.pos_eq_ra
<xarray.DataArray 'pos_eq_ra' (dim0: 720, dim1: 721)> Size: 4MB
[519120 values with dtype=float64]
Coordinates:
  * pos_eq_ra   (dim0, dim1) float64 4MB 267.0 267.0 267.0 ... 265.8 265.8 265.8
  * pos_eq_dec  (dim0, dim1) float64 4MB -29.43 -29.43 -29.43 ... -28.43 -28.43
Dimensions without coordinates: dim0, dim1
Indexes:
  ┌ pos_eq_ra   CoordinateTransformIndex
  └ pos_eq_dec
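The Size shown for each lazy coordinate is the nominal size its values would occupy once computed, not memory in use (note the "[519120 values with dtype=float64]" placeholder instead of actual values). These figures follow from the number of elements times the item size, as the short calculation below shows for both the 2-D image here and the 53 x 105 x 105 spectral cube opened later; nbytes is a hypothetical helper.

```python
import numpy as np


def nbytes(shape, dtype):
    # Bytes an array of this shape and dtype would take once materialized.
    return int(np.prod(shape)) * np.dtype(dtype).itemsize


shape_2d = (720, 721)
print(nbytes(shape_2d, "float32"))  # 2076480: image data (~2 MB)
print(nbytes(shape_2d, "float64"))  # 4152960: per coordinate if loaded (~4 MB)

shape_3d = (53, 105, 105)
print(nbytes(shape_3d, ">f4"))      # 2337300: cube data (~2 MB)
print(nbytes(shape_3d, "float64"))  # 4674600: per coordinate if loaded (~5 MB)
```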
da_2d.plot.pcolormesh(
    x="pos_eq_ra",
    y="pos_eq_dec",
    vmax=1300,
    cmap="magma",
);
[Figure: pcolormesh plot of da_2d with pos_eq_ra on the x axis and pos_eq_dec on the y axis (magma colormap, vmax=1300)]

Open a spectral cube with two celestial axes and one spectral axis.

fname = get_pkg_data_filename("l1448/l1448_13co.fits")

da_3d = open_fits_dataarray(fname)
da_3d
<xarray.DataArray (dim0: 53, dim1: 105, dim2: 105)> Size: 2MB
array([[[ 1.565e-01,  1.617e-01, ...,  6.135e-01,  8.006e-01],
        [-6.556e-02,  1.834e-01, ...,  4.520e-01,  5.302e-01],
        ...,
        [-1.023e-01, -2.522e-01, ..., -1.493e-01, -7.321e-02],
        [ 2.357e-01, -1.268e-01, ...,  2.561e-01, -7.383e-02]],

       [[ 1.803e-01,  3.841e-02, ...,  3.219e-01,  3.325e-01],
        [ 5.100e-02, -1.466e-01, ...,  2.972e-01,  4.661e-01],
        ...,
        [-4.534e-04, -1.212e-01, ..., -1.109e-01, -9.869e-02],
        [ 2.601e-01, -7.199e-02, ...,  5.910e-02,  3.332e-01]],

       ...,

       [[ 7.934e-02,  1.493e-01, ...,  2.147e-01,  2.845e-01],
        [ 3.147e-01,  4.547e-01, ...,  1.561e-01,  3.217e-01],
        ...,
        [ 9.095e-01,  8.347e-01, ...,  2.485e-01, -3.963e-02],
        [ 6.973e-01,  4.107e-01, ..., -5.433e-02, -3.194e-01]],

       [[ 1.101e-01,  1.940e-01, ..., -1.001e-01,  1.648e-01],
        [ 9.819e-02,  1.949e-01, ...,  2.342e-01, -2.520e-01],
        ...,
        [ 3.003e-01,  3.390e-01, ...,  1.896e-01,  2.429e-01],
        [ 4.334e-01,  5.407e-01, ..., -1.176e-01, -4.120e-02]]],
      shape=(53, 105, 105), dtype='>f4')
Coordinates:
  * pos_eq_ra               (dim0, dim1, dim2) float64 5MB 51.74 51.73 ... 50.92
  * pos_eq_dec              (dim0, dim1, dim2) float64 5MB 30.3 30.3 ... 30.97
  * spect_dopplerVeloc_opt  (dim0, dim1, dim2) float64 5MB 2.528e+03 ... 5.98...
Dimensions without coordinates: dim0, dim1, dim2
Indexes:
  ┌ pos_eq_ra               CoordinateTransformIndex
  │ pos_eq_dec
  └ spect_dopplerVeloc_opt
Attributes:
    wcs:      WCS Keywords\n\nNumber of WCS axes: 3\nCTYPE : 'RA---SFL' 'DEC-...