Nearest neighbors with NDPointIndex#

Tip

New to spatial trees? Start with Tree-Based Indexing to learn how tree structures enable fast nearest-neighbor search, and when you need alternatives like Ball trees for geographic data.

Highlights#

  1. xarray.indexes.NDPointIndex is useful for dealing with n-dimensional coordinate variables representing irregular data.

  2. It enables efficient point-wise (nearest-neighbors) data selection using Xarray’s advanced indexing capabilities.

  3. By default, a scipy.spatial.KDTree is used under the hood for fast lookup of point data. It is also possible to plug alternative tree structures into NDPointIndex, although this feature is experimental (see Advanced usage).

Basic usage#

A common task: you have scattered observations and want to color a regular grid based on the nearest observation at each grid point. This requires finding the nearest neighbor for every grid cell.

Let’s create some scattered measurements:

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
shape = (5, 10)
xx = xr.DataArray(
    np.random.uniform(0, 10, size=shape), dims=("y", "x")
)
yy = xr.DataArray(
    np.random.uniform(0, 5, size=shape), dims=("y", "x")
)
data = (xx - 5) ** 2 + (yy - 2.5) ** 2

ds = xr.Dataset(data_vars={"data": data}, coords={"xx": xx, "yy": yy})
ds
<xarray.Dataset> Size: 1kB
Dimensions:  (y: 5, x: 10)
Coordinates:
    xx       (y, x) float64 400B 1.636 5.304 8.267 8.642 ... 4.706 2.265 9.191
    yy       (y, x) float64 400B 3.925 4.306 2.534 4.908 ... 2.678 4.991 2.931
Dimensions without coordinates: y, x
Data variables:
    data     (y, x) float64 400B 13.35 3.356 10.67 19.06 ... 0.1181 13.69 17.75

# Show the scattered source points and the empty grid we want to fill
fig, ax = plt.subplots(figsize=(10, 5))

# Draw empty grid lines (the grid we want to color)
for x in range(11):
    ax.axvline(x - 0.5, color='lightgray', lw=0.8, zorder=1)
for y in range(6):
    ax.axhline(y - 0.5, color='lightgray', lw=0.8, zorder=1)

# Plot the scattered source points with colors
scatter = ax.scatter(ds.xx.values.ravel(), ds.yy.values.ravel(),
                     c=ds.data.values.ravel(), s=60, cmap='viridis',
                     edgecolors='black', linewidths=0.5, zorder=5)
plt.colorbar(scatter, ax=ax, label='data')

ax.set_xlim(-0.5, 10.5)
ax.set_ylim(-0.5, 5.5)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Scattered observations + empty grid (goal: color each grid cell)')
plt.tight_layout()
plt.show()

Assigning the index#

To enable nearest-neighbor lookups, we attach an NDPointIndex to our coordinate variables using set_xindex. This builds a KD-tree from the coordinate values, which will be used for efficient spatial queries.

The tuple ("xx", "yy") tells xarray which coordinates to combine into a multi-dimensional index.

Note

The tree is built once when you call set_xindex. This has a small upfront cost, but all subsequent queries are fast. This makes NDPointIndex ideal when you need to query the same dataset many times.

ds_index = ds.set_xindex(("xx", "yy"), xr.indexes.NDPointIndex)
ds_index
<xarray.Dataset> Size: 1kB
Dimensions:  (y: 5, x: 10)
Coordinates:
  * xx       (y, x) float64 400B 1.636 5.304 8.267 8.642 ... 4.706 2.265 9.191
  * yy       (y, x) float64 400B 3.925 4.306 2.534 4.908 ... 2.678 4.991 2.931
Dimensions without coordinates: y, x
Data variables:
    data     (y, x) float64 400B 13.35 3.356 10.67 19.06 ... 0.1181 13.69 17.75
Indexes:
  ┌ xx       NDPointIndex (ScipyKDTreeAdapter)
  └ yy

Querying the index#

Now we can query for nearest neighbors using familiar xarray syntax. Under the hood, .sel(..., method="nearest") calls KDTree.query to efficiently find the closest point:

ds_index.sel(xx=3.4, yy=4.2, method="nearest")
<xarray.Dataset> Size: 24B
Dimensions:  ()
Coordinates:
    xx       float64 8B 2.674
    yy       float64 8B 4.56
Data variables:
    data     float64 8B 9.654

query_x, query_y = 3.4, 4.2
result = ds_index.sel(xx=query_x, yy=query_y, method="nearest")
nearest_x = result.xx.item()
nearest_y = result.yy.item()

fig, ax = plt.subplots(figsize=(10, 5))

# All points at low alpha
ax.scatter(
    ds.xx.values.ravel(), ds.yy.values.ravel(),
    c='steelblue', s=40, alpha=0.2, edgecolors='gray', linewidths=0.5, zorder=3,
)

# Line connecting query to nearest
ax.annotate(
    '', xy=(nearest_x, nearest_y), xytext=(query_x, query_y),
    arrowprops=dict(arrowstyle='-', color='black', lw=1.5, ls='--'),
    zorder=4,
)

# Nearest point highlighted
ax.scatter(
    nearest_x, nearest_y,
    c='steelblue', s=120, edgecolors='black', linewidths=1.5, zorder=6,
    label=f'nearest ({nearest_x:.1f}, {nearest_y:.1f})',
)

# Query point
ax.scatter(
    query_x, query_y,
    c='red', s=120, marker='X', linewidths=0.5, edgecolors='darkred', zorder=6,
    label=f'query ({query_x}, {query_y})',
)

ax.set_xlim(-0.5, 10.5)
ax.set_ylim(-0.5, 5.5)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Nearest-neighbor lookup')
ax.legend()
plt.tight_layout()
plt.show()

Assigning values from scattered points to a grid#

Sometimes you need to transfer scattered observations onto a regular grid. The simplest approach is nearest-neighbor lookup: each grid cell gets the value of its closest observation.

Note

This is not interpolation—there’s no averaging or blending. Each grid cell simply takes the value of one source point.

Without a spatial index, finding nearest neighbors requires n_grid × n_points comparisons (2,500 for 50 grid cells and 50 points). With NDPointIndex, each lookup takes O(log n) time on average, where n is the number of source points.
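As a back-of-the-envelope check of these figures, the sketch below counts brute-force comparisons and an idealized tree cost of about log2(n) node visits per query. The helper names are hypothetical, and the tree estimate assumes a perfectly balanced binary tree with a single root-to-leaf descent; real KD-tree queries visit somewhat more nodes because of backtracking.

```python
import math


def brute_force_comparisons(n_queries, n_points):
    # Every query is compared against every stored point.
    return n_queries * n_points


def idealized_tree_visits(n_queries, n_points):
    # Assumption: balanced binary tree, one root-to-leaf descent per query.
    return n_queries * math.ceil(math.log2(n_points))


n_grid, n_obs = 50, 50
print(brute_force_comparisons(n_grid, n_obs))  # 2500
print(idealized_tree_visits(n_grid, n_obs))    # 300
```

The gap is modest for 50 points but widens quickly, since the brute-force count grows linearly with the number of stored points while the tree estimate grows only logarithmically.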

Let’s assign values to a 10×5 grid from our 50 scattered observations:

# create a regular grid as query points
ds_grid = xr.Dataset(coords={"x": range(10), "y": range(5)})

# selection supports broadcasting of the input labels
ds_selection = ds_index.sel(
    xx=ds_grid.x, yy=ds_grid.y, method="nearest"
)

# assign selection results to the grid
# -> each grid cell takes the value of its nearest observation (no interpolation)
ds_grid["data"] = ds_selection.data.variable

ds_grid
<xarray.Dataset> Size: 520B
Dimensions:  (x: 10, y: 5)
Coordinates:
  * x        (x) int64 80B 0 1 2 3 4 5 6 7 8 9
  * y        (y) int64 40B 0 1 2 3 4
Data variables:
    data     (x, y) float64 400B 20.49 20.49 24.79 19.21 ... 19.09 17.75 18.51
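The same assignment can be sketched without xarray to show what the broadcast selection amounts to. This is an illustration under stated assumptions: the observations are random stand-ins, and a brute-force `min` replaces the tree lookup. Every combination of the 10 `x` labels and 5 `y` labels becomes one query point, and each grid cell copies the value of a single observation.

```python
import itertools
import math
import random

random.seed(42)
# 50 hypothetical scattered observations, stored as ((x, y), value)
observations = [
    ((random.uniform(0, 10), random.uniform(0, 5)), random.random())
    for _ in range(50)
]


def nearest_value(query):
    # Brute-force stand-in for the tree lookup: the closest observation wins.
    _, value = min(observations, key=lambda obs: math.dist(obs[0], query))
    return value


# Broadcasting x (10 labels) against y (5 labels) yields 10 * 5 query points.
grid = {
    (x, y): nearest_value((x, y))
    for x, y in itertools.product(range(10), range(5))
}

source_values = {value for _, value in observations}
print(len(grid))                            # 50
print(set(grid.values()) <= source_values)  # True
```

Because no averaging takes place, every value on the grid also appears among the source values, and one observation may be reused by several neighboring cells.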

# Visualize the result
fig, ax = plt.subplots(figsize=(10, 5))

# Plot the grid
ds_grid.data.plot(x="x", y="y", ax=ax, add_colorbar=True)

# Draw lines connecting each grid cell to its nearest source point
for i, x_val in enumerate(ds_grid.x.values):
    for j, y_val in enumerate(ds_grid.y.values):
        # Find which source point was selected for this grid cell
        nearest_x = ds_selection.xx.sel(x=x_val, y=y_val).values
        nearest_y = ds_selection.yy.sel(x=x_val, y=y_val).values
        # Draw line from grid cell center to source point
        ax.plot([x_val, nearest_x], [y_val, nearest_y],
                color='white', alpha=0.7, lw=1.2, zorder=3)

# Plot the original scattered points on top
ax.scatter(ds.xx.values, ds.yy.values, c='black', s=25, zorder=5,
           edgecolors='white', linewidths=0.5, label='Source points')

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Each grid cell takes the value of its nearest source point')
ax.legend(loc='upper right')
plt.tight_layout()
plt.show()

Advanced usage#

A real-world use case: you have ocean model output on a curvilinear grid and want to color a trajectory based on the nearest model values.

This example uses the Regional Ocean Modeling System (ROMS) Xarray example. The data is on a grid of eta_rho × xi_rho (191 × 371 = 70,861 points), but we want to query by lat_rho and lon_rho.

ds_roms = xr.tutorial.open_dataset("ROMS_example")
ds_roms
<xarray.Dataset> Size: 19MB
Dimensions:     (ocean_time: 2, s_rho: 30, eta_rho: 191, xi_rho: 371)
Coordinates:
  * ocean_time  (ocean_time) datetime64[ns] 16B 2001-08-01 2001-08-08
  * s_rho       (s_rho) float64 240B -0.9833 -0.95 -0.9167 ... -0.05 -0.01667
    Cs_r        (s_rho) float64 240B ...
    lon_rho     (eta_rho, xi_rho) float64 567kB ...
    h           (eta_rho, xi_rho) float64 567kB ...
    lat_rho     (eta_rho, xi_rho) float64 567kB ...
    hc          float64 8B ...
    Vtransform  int32 4B ...
Dimensions without coordinates: eta_rho, xi_rho
Data variables:
    salt        (ocean_time, s_rho, eta_rho, xi_rho) float32 17MB ...
    zeta        (ocean_time, eta_rho, xi_rho) float32 567kB ...
Attributes: (12/34)
    file:              ../output_20yr_obc/2001/ocean_his_0015.nc
    format:            netCDF-4/HDF5 file
    Conventions:       CF-1.4
    type:              ROMS/TOMS history file
    title:             TXLA ROMS hindcast run with dyes and oxygen
    rst_file:          ../output_20yr_obc/2001/ocean_rst.nc
    ...                ...
    compiler_flags:    -heap-arrays -fp-model fast -mt_mpi -ip -O3 -msse2 -free
    tiling:            010x012
    history:           Tue Jul 24 11:04:43 2018: /opt/nco/ncks -D 4 -t 8 /cop...
    ana_file:          /home/d.kobashi/TXLA_ROMS_reana/Functionals/ana_btflux...
    CPP_options:       TXLA2, ANA_BPFLUX, ANA_BSFLUX, ANA_BTFLUX, ANA_NUDGCOE...
    NCO:               netCDF Operators version 4.7.6-alpha04 (Homepage = htt...

The challenge#

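We want to color 50 trajectory points based on the nearest of the 70,861 (191 × 371) ocean model grid points.

Without a KD-tree: 50 × 70,861 ≈ 3.5 million distance calculations

With NDPointIndex: roughly 50 × log2(70,861) ≈ 50 × 16 = 800 tree-node visits in the idealized case, thousands of times fewer (measured wall-clock speedups are smaller; see Performance comparison)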

import matplotlib.pyplot as plt

ds_trajectory = xr.Dataset(
    coords={
        "lat": ("trajectory", np.linspace(28, 30, 50)),
        "lon": ("trajectory", np.linspace(-93, -88, 50)),
    },
)

ds_roms.salt.isel(s_rho=-1, ocean_time=0).plot(
    x="lon_rho", y="lat_rho", alpha=0.6
)
plt.plot(
    ds_trajectory.lon.data,
    ds_trajectory.lat.data,
    marker=".",
    color="k",
    ms=4,
    ls="none",
)
plt.show()

Assigning the index#

First, add the index to the lat_rho and lon_rho coordinate variables using set_xindex:

ds_roms_index = ds_roms.set_xindex(
    ("lat_rho", "lon_rho"),
    xr.indexes.NDPointIndex,
)
ds_roms_index
<xarray.Dataset> Size: 19MB
Dimensions:     (ocean_time: 2, s_rho: 30, eta_rho: 191, xi_rho: 371)
Coordinates:
  * ocean_time  (ocean_time) datetime64[ns] 16B 2001-08-01 2001-08-08
  * s_rho       (s_rho) float64 240B -0.9833 -0.95 -0.9167 ... -0.05 -0.01667
    Cs_r        (s_rho) float64 240B ...
    h           (eta_rho, xi_rho) float64 567kB ...
  * lat_rho     (eta_rho, xi_rho) float64 567kB 27.45 27.45 ... 30.85 30.86
  * lon_rho     (eta_rho, xi_rho) float64 567kB -93.6 -93.58 ... -88.88 -88.87
    hc          float64 8B ...
    Vtransform  int32 4B ...
Dimensions without coordinates: eta_rho, xi_rho
Data variables:
    salt        (ocean_time, s_rho, eta_rho, xi_rho) float32 17MB ...
    zeta        (ocean_time, eta_rho, xi_rho) float32 567kB ...
Indexes:
  ┌ lat_rho  NDPointIndex (ScipyKDTreeAdapter)
  └ lon_rho
Attributes: (12/34)
    file:              ../output_20yr_obc/2001/ocean_his_0015.nc
    format:            netCDF-4/HDF5 file
    Conventions:       CF-1.4
    type:              ROMS/TOMS history file
    title:             TXLA ROMS hindcast run with dyes and oxygen
    rst_file:          ../output_20yr_obc/2001/ocean_rst.nc
    ...                ...
    compiler_flags:    -heap-arrays -fp-model fast -mt_mpi -ip -O3 -msse2 -free
    tiling:            010x012
    history:           Tue Jul 24 11:04:43 2018: /opt/nco/ncks -D 4 -t 8 /cop...
    ana_file:          /home/d.kobashi/TXLA_ROMS_reana/Functionals/ana_btflux...
    CPP_options:       TXLA2, ANA_BPFLUX, ANA_BSFLUX, ANA_BTFLUX, ANA_NUDGCOE...
    NCO:               netCDF Operators version 4.7.6-alpha04 (Homepage = htt...

Coloring the trajectory#

Now we can efficiently find the nearest ocean model point for each trajectory point:

ds_roms_selection = ds_roms_index.sel(
    lat_rho=ds_trajectory.lat,
    lon_rho=ds_trajectory.lon,
    method="nearest",
)
ds_roms_selection
<xarray.Dataset> Size: 14kB
Dimensions:     (ocean_time: 2, s_rho: 30, trajectory: 50)
Coordinates:
  * ocean_time  (ocean_time) datetime64[ns] 16B 2001-08-01 2001-08-08
  * s_rho       (s_rho) float64 240B -0.9833 -0.95 -0.9167 ... -0.05 -0.01667
    Cs_r        (s_rho) float64 240B ...
    h           (trajectory) float64 400B ...
    lat_rho     (trajectory) float64 400B 28.01 28.03 28.07 ... 29.96 29.95
    lon_rho     (trajectory) float64 400B -93.0 -92.9 -92.8 ... -88.11 -88.09
    hc          float64 8B ...
    Vtransform  int32 4B ...
Dimensions without coordinates: trajectory
Data variables:
    salt        (ocean_time, s_rho, trajectory) float32 12kB ...
    zeta        (ocean_time, trajectory) float32 400B ...
Attributes: (12/34)
    file:              ../output_20yr_obc/2001/ocean_his_0015.nc
    format:            netCDF-4/HDF5 file
    Conventions:       CF-1.4
    type:              ROMS/TOMS history file
    title:             TXLA ROMS hindcast run with dyes and oxygen
    rst_file:          ../output_20yr_obc/2001/ocean_rst.nc
    ...                ...
    compiler_flags:    -heap-arrays -fp-model fast -mt_mpi -ip -O3 -msse2 -free
    tiling:            010x012
    history:           Tue Jul 24 11:04:43 2018: /opt/nco/ncks -D 4 -t 8 /cop...
    ana_file:          /home/d.kobashi/TXLA_ROMS_reana/Functionals/ana_btflux...
    CPP_options:       TXLA2, ANA_BPFLUX, ANA_BSFLUX, ANA_BTFLUX, ANA_NUDGCOE...
    NCO:               netCDF Operators version 4.7.6-alpha04 (Homepage = htt...
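Conceptually, a tree built over a 2-D curvilinear grid stores a flattened list of (lat, lon) points, so the position returned by a query has to be mapped back to a pair of (eta_rho, xi_rho) indices. The sketch below illustrates that bookkeeping in plain Python, assuming row-major (C-order) flattening. It illustrates the idea rather than describing NDPointIndex internals, and the `unravel` helper is hypothetical.

```python
ETA_RHO, XI_RHO = 191, 371  # grid shape taken from the dataset repr


def unravel(flat_index, shape=(ETA_RHO, XI_RHO)):
    # Row-major layout: the last dimension (xi_rho) varies fastest.
    n_rows, n_cols = shape
    if not 0 <= flat_index < n_rows * n_cols:
        raise IndexError(f"flat index {flat_index} out of range")
    return divmod(flat_index, n_cols)


print(ETA_RHO * XI_RHO)  # 70861 stored points
print(unravel(0))        # (0, 0)
print(unravel(371))      # (1, 0)
print(unravel(70860))    # (190, 370)
```

This is why the selection result has a single trajectory dimension in place of eta_rho and xi_rho: each trajectory point resolves to exactly one grid cell.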

The trajectory points are now colored by the salinity of the nearest ocean model grid point, computed efficiently with the KD-tree under the hood:

fig, ax = plt.subplots(figsize=(10, 6))

# Plot ocean model salinity as background
ds_roms.salt.isel(s_rho=-1, ocean_time=0).plot(
    x="lon_rho", y="lat_rho", vmin=0, vmax=35, alpha=0.3, ax=ax, add_colorbar=True
)

# Plot trajectory colored by nearest salinity values
scatter = ax.scatter(
    ds_trajectory.lon.data,
    ds_trajectory.lat.data,
    c=ds_roms_selection.isel(s_rho=-1, ocean_time=0).salt,
    s=15,
    vmin=0,
    vmax=35,
    edgecolors="k",
    linewidths=0.5,
)

ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
ax.set_title('Ship trajectory colored by nearest ocean model salinity')
plt.tight_layout()
plt.show()

Alternative trees with xoak#

The default KD-tree uses Euclidean distance, which works well for most cases. However, for geographic coordinates (lat/lon), it can give incorrect results at high latitudes because a degree of longitude spans a shorter distance toward the poles, while a degree of latitude stays roughly constant.
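To put numbers on this effect, the sketch below estimates the length of one degree of longitude at a few latitudes. It assumes a spherical Earth with roughly 111 km per degree of latitude; the constant and the helper name are illustrative.

```python
import math

KM_PER_DEG_LAT = 111.0  # approximate length of one degree of latitude


def km_per_deg_lon(lat_deg):
    # A degree of longitude shrinks with the cosine of the latitude.
    return KM_PER_DEG_LAT * math.cos(math.radians(lat_deg))


for lat in (0, 45, 80):
    print(f"{lat} deg N: {km_per_deg_lon(lat):.1f} km per degree of longitude")
```

Under these assumptions a degree of longitude covers about 111 km at the equator, about 78 km at 45°N, and only about 19 km at 80°N, so treating degrees as Cartesian units increasingly overstates east-west distances.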

Tip

See Tree-Based Indexing - The problem with geographic coordinates for a detailed explanation with visualizations.

Using a Ball tree for geographic data#

For accurate great-circle distance calculations, you can plug in a Ball tree with the haversine metric. The xoak package provides ready-made tree adapters for common use cases, including SklearnGeoBallTreeAdapter which wraps scikit-learn’s BallTree with the haversine metric.

Pass it as the tree_adapter_cls argument to set_xindex — everything else stays the same:

from xoak import SklearnGeoBallTreeAdapter

ds_roms_ball_index = ds_roms.set_xindex(
    ("lat_rho", "lon_rho"),
    xr.indexes.NDPointIndex,
    tree_adapter_cls=SklearnGeoBallTreeAdapter,
)
ds_roms_ball_index
<xarray.Dataset> Size: 19MB
Dimensions:     (ocean_time: 2, s_rho: 30, eta_rho: 191, xi_rho: 371)
Coordinates:
  * ocean_time  (ocean_time) datetime64[ns] 16B 2001-08-01 2001-08-08
  * s_rho       (s_rho) float64 240B -0.9833 -0.95 -0.9167 ... -0.05 -0.01667
    Cs_r        (s_rho) float64 240B ...
    h           (eta_rho, xi_rho) float64 567kB ...
  * lat_rho     (eta_rho, xi_rho) float64 567kB 27.45 27.45 ... 30.85 30.86
  * lon_rho     (eta_rho, xi_rho) float64 567kB -93.6 -93.58 ... -88.88 -88.87
    hc          float64 8B ...
    Vtransform  int32 4B ...
Dimensions without coordinates: eta_rho, xi_rho
Data variables:
    salt        (ocean_time, s_rho, eta_rho, xi_rho) float32 17MB ...
    zeta        (ocean_time, eta_rho, xi_rho) float32 567kB ...
Indexes:
  ┌ lat_rho  NDPointIndex (SklearnGeoBallTreeAdapter)
  └ lon_rho
Attributes: (12/34)
    file:              ../output_20yr_obc/2001/ocean_his_0015.nc
    format:            netCDF-4/HDF5 file
    Conventions:       CF-1.4
    type:              ROMS/TOMS history file
    title:             TXLA ROMS hindcast run with dyes and oxygen
    rst_file:          ../output_20yr_obc/2001/ocean_rst.nc
    ...                ...
    compiler_flags:    -heap-arrays -fp-model fast -mt_mpi -ip -O3 -msse2 -free
    tiling:            010x012
    history:           Tue Jul 24 11:04:43 2018: /opt/nco/ncks -D 4 -t 8 /cop...
    ana_file:          /home/d.kobashi/TXLA_ROMS_reana/Functionals/ana_btflux...
    CPP_options:       TXLA2, ANA_BPFLUX, ANA_BSFLUX, ANA_BTFLUX, ANA_NUDGCOE...
    NCO:               netCDF Operators version 4.7.6-alpha04 (Homepage = htt...

The adapter handles converting lat/lon from degrees to radians internally, so you can query with the same degree-valued coordinates as before.
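To make the haversine metric concrete, here is a minimal plain-Python sketch of the formula. It is not the xoak or scikit-learn implementation; the function name, the mean Earth radius of 6371 km, and the three sample points near 80°N are assumptions chosen for illustration.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius (spherical approximation)


def haversine_km(p, q):
    """Great-circle distance between two (lat, lon) points given in degrees."""
    lat1, lon1, lat2, lon2 = map(math.radians, (*p, *q))
    a = (
        math.sin((lat2 - lat1) / 2) ** 2
        + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2
    )
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))


query = (80.0, 0.0)
point_a = (80.0, 2.0)  # 2 degrees east of the query
point_b = (80.5, 0.0)  # 0.5 degrees north of the query

for name, point in (("A", point_a), ("B", point_b)):
    degrees = math.dist(query, point)   # Euclidean distance on raw degrees
    km = haversine_km(query, point)     # great-circle distance
    print(f"{name}: {degrees:.1f} deg, {km:.1f} km")
```

Measured in raw degrees, B is the nearer point (0.5 versus 2.0), but along the sphere A is nearer (roughly 39 km versus 56 km), so the two metrics disagree about which neighbor is closest.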

At high latitudes, longitude degrees are much shorter than latitude degrees. A KD-tree and a Ball tree can therefore pick different nearest neighbors for the same query:

# Two points near 80°N: A lies east of the query (along longitude),
# B lies north of it (along latitude).
# The query is closer in *degrees* to the north point (B),
# but closer in *km* to the east point (A).
point_a = (80.0, 2.0)   # 2° east of query
point_b = (80.5, 0.0)   # 0.5° north of query
query = (80.0, 0.0)

# Use 2D coords on a dummy grid (NDPointIndex requires shared dims)
ds_demo = xr.Dataset(
    {"label": (("i", "j"), [["A", "B"]])},
    coords={
        "lat": (("i", "j"), [[point_a[0], point_b[0]]]),
        "lon": (("i", "j"), [[point_a[1], point_b[1]]]),
    },
)

# KD-tree (Euclidean on degrees)
ds_kd = ds_demo.set_xindex(("lat", "lon"), xr.indexes.NDPointIndex)
result_kd = ds_kd.sel(lat=query[0], lon=query[1], method="nearest")

# Ball tree (haversine)
ds_ball = ds_demo.set_xindex(
    ("lat", "lon"), xr.indexes.NDPointIndex,
    tree_adapter_cls=SklearnGeoBallTreeAdapter,
)
result_ball = ds_ball.sel(lat=query[0], lon=query[1], method="nearest")

kd_pick = (result_kd.lat.item(), result_kd.lon.item())
ball_pick = (result_ball.lat.item(), result_ball.lon.item())

km_per_deg_lon = 111.0 * np.cos(np.radians(80.0))
km_to_a = abs(query[1] - point_a[1]) * km_per_deg_lon
km_to_b = abs(query[0] - point_b[0]) * 111.0

# Convert to km relative to query
def to_km(lat, lon):
    return (lon - query[1]) * km_per_deg_lon, (lat - query[0]) * 111.0

a_km = to_km(*point_a)
b_km = to_km(*point_b)
q_km = (0.0, 0.0)

fig, (ax_deg, ax_km) = plt.subplots(1, 2, figsize=(13, 5))

for ax, coords, xlabel, ylabel, title, a_label, b_label in [
    (ax_deg,
     {"a": (point_a[1], point_a[0]), "b": (point_b[1], point_b[0]), "q": (query[1], query[0])},
     "Longitude (°)", "Latitude (°)",
     "In degrees: B looks closer",
     f"A (2.0°)", f"B (0.5°)"),
    (ax_km,
     {"a": (a_km[0], a_km[1]), "b": (b_km[0], b_km[1]), "q": q_km},
     "East-West (km)", "North-South (km)",
     "In km: A is actually closer",
     f"A ({km_to_a:.0f} km)", f"B ({km_to_b:.0f} km)"),
]:
    # Line to A (Ball tree / haversine picks this)
    ax.annotate(
        "", xy=coords["a"], xytext=coords["q"],
        arrowprops=dict(arrowstyle="-", color="coral", lw=2, ls="--"), zorder=4,
    )
    # Line to B (KD-tree picks this)
    ax.annotate(
        "", xy=coords["b"], xytext=coords["q"],
        arrowprops=dict(arrowstyle="-", color="steelblue", lw=2, ls="--"), zorder=4,
    )

    # Label the lines
    mid_a = ((coords["q"][0] + coords["a"][0]) / 2, (coords["q"][1] + coords["a"][1]) / 2)
    mid_b = ((coords["q"][0] + coords["b"][0]) / 2, (coords["q"][1] + coords["b"][1]) / 2)
    ax.annotate("haversine", mid_a, xytext=(8, -10), textcoords="offset points",
                fontsize=9, color="coral", fontstyle="italic")
    ax.annotate("KD-tree", mid_b, xytext=(8, 4), textcoords="offset points",
                fontsize=9, color="steelblue", fontstyle="italic")

    # Points
    ax.scatter(*coords["a"], c="coral", s=140, edgecolors="black", linewidths=1.5, zorder=6)
    ax.scatter(*coords["b"], c="steelblue", s=140, edgecolors="black", linewidths=1.5, zorder=6)
    ax.scatter(*coords["q"], c="red", s=180, marker="X",
               edgecolors="darkred", linewidths=0.5, zorder=7)

    # Point labels
    ax.annotate(a_label, coords["a"], xytext=(12, -14), textcoords="offset points",
                fontsize=10, fontweight="bold")
    ax.annotate(b_label, coords["b"], xytext=(12, 8), textcoords="offset points",
                fontsize=10, fontweight="bold")
    ax.annotate("query", coords["q"], xytext=(12, -14), textcoords="offset points",
                fontsize=10, color="red", fontweight="bold")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight="bold")
    ax.grid(True, alpha=0.3)

# Match x and y limits so distances are visually comparable
deg_lim = max(abs(point_a[1] - query[1]), abs(point_b[0] - query[0])) * 1.3
ax_deg.set_xlim(-deg_lim * 0.15, deg_lim)
ax_deg.set_ylim(query[0] - deg_lim * 0.15, query[0] + deg_lim)

km_lim = max(abs(a_km[0]), abs(a_km[1]), abs(b_km[0]), abs(b_km[1])) * 1.3
ax_km.set_xlim(-km_lim * 0.15, km_lim)
ax_km.set_ylim(-km_lim * 0.15, km_lim)

plt.tight_layout()
plt.show()

Performance comparison#

Let’s compare the actual lookup times for our trajectory query:

import time

# Get the data points and query points
lat_rho = ds_roms.lat_rho.values.ravel()
lon_rho = ds_roms.lon_rho.values.ravel()
points = np.column_stack([lat_rho, lon_rho])

query_lats = ds_trajectory.lat.values
query_lons = ds_trajectory.lon.values
queries = np.column_stack([query_lats, query_lons])

n_points = len(points)
n_queries = len(queries)

# Time naive search
def naive_nearest(points, queries):
    indices = []
    for q in queries:
        dists = np.sum((points - q)**2, axis=1)
        indices.append(np.argmin(dists))
    return np.array(indices)

start = time.perf_counter()
for _ in range(100):
    naive_nearest(points, queries)
naive_time = (time.perf_counter() - start) / 100 * 1000

# Time KD-tree
start = time.perf_counter()
for _ in range(100):
    ds_roms_index.sel(lat_rho=ds_trajectory.lat, lon_rho=ds_trajectory.lon, method="nearest")
kdtree_time = (time.perf_counter() - start) / 100 * 1000

# Time Ball tree
start = time.perf_counter()
for _ in range(100):
    ds_roms_ball_index.sel(lat_rho=ds_trajectory.lat, lon_rho=ds_trajectory.lon, method="nearest")
balltree_time = (time.perf_counter() - start) / 100 * 1000

# Plot comparison
fig, ax = plt.subplots(figsize=(10, 5))

methods = ['Naive\n(brute force)', 'KD-tree\n(Euclidean)', 'Ball tree\n(haversine)']
times = [naive_time, kdtree_time, balltree_time]
colors = ['gray', 'steelblue', 'coral']

bars = ax.bar(methods, times, color=colors)
ax.set_ylabel(f'Time per batch of {n_queries} lookups (ms)')
ax.set_title(f'Finding nearest neighbors: {n_queries} trajectory points across {n_points} ocean grid points')

for bar, t in zip(bars, times):
    ax.annotate(f'{t:.2f} ms', (bar.get_x() + bar.get_width()/2, bar.get_height()),
                ha='center', va='bottom', fontsize=11, fontweight='bold')

ax.set_ylim(0, max(times) * 1.2)
plt.tight_layout()
plt.show()

print(f"KD-tree is {naive_time/kdtree_time:.1f}x faster than naive")
print(f"Ball tree is {naive_time/balltree_time:.1f}x faster than naive")
print("\nThe tradeoff: KD-tree is faster but treats lat/lon as flat.")
print("Ball tree is slower but uses haversine for correct great-circle distances.")
KD-tree is 55.9x faster than naive
Ball tree is 35.7x faster than naive

The tradeoff: KD-tree is faster but treats lat/lon as flat.
Ball tree is slower but uses haversine for correct great-circle distances.

Choosing a tree#

| Use case | Recommended | Why |
|---|---|---|
| Cartesian coordinates (x, y in meters) | KD-tree (default) | Fastest option, Euclidean distance is correct |
| Geographic data near the equator | KD-tree | Euclidean on lat/lon is approximately correct |
| Geographic data at high latitudes | Ball tree with haversine | Euclidean on lat/lon gives wrong answers |
| Custom distance metric needed | Ball tree | Supports haversine, Manhattan, and others |
| Very high dimensions (>20) | Ball tree | KD-trees degrade in high dimensions |

See also