from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import numpy
from scipy.interpolate import griddata
from .. import units
from ..io import hdf5
from . import distance
from . import grid_utils
from . import pad
from .optimal_grid import optimal_grid_axes
[docs]
class ScatterDataInterpolator:
"""Interpolates scatter data on a regular grid."""
def __init__(
self,
scatter_coordinates: Sequence[numpy.ndarray],
scatter_coordinate_names: Sequence[str],
scatter_coordinate_units: Sequence[Optional[str]],
method: Optional[str] = None,
fill_value=None,
fix_resolution: bool = True,
fix_limits: bool = True,
resolution: Optional[
Union[Sequence[float], Sequence[Tuple[float, str]]]
] = None,
outside_resolution_fraction: float = 2.1,
scale_method: Optional[Literal["range", "std"]] = "range",
reference_units: Optional[Dict[str, str]] = None,
reference_units_fallback: Optional[
Literal["first", "largest", "smallest"]
] = "smallest",
):
"""
:param scatter_coordinates: independent variables with shape `(Ndim, Nscatter)`.
:param scatter_coordinate_names: independent variable names with shape `(Ndim,)`
:param scatter_coordinate_units: independent variable units with shape `(Ndim,)`
:param method: interpolation method.
:param fill_value: value used for extrapolation.
:param fix_resolution: fix the resolution.
:param fix_limits: fix the limits.
:param resolution: grid spacing in each dimension.
:param outside_resolution_fraction: a grid point is concidered to be outside the space spanned by the scatter points
when the closest scatter point is farther away than a fraction of the smallest resolution
across all dimensions.
:param scale_method: scale coordinates for interpolation and distance calculation.
:param reference_units: mapping dimensionality -> unit.
:param reference_units_fallback: strategy for unspecified dimensionalities.
"""
if method is None:
method = "nearest"
self._method = method
if fill_value is None:
fill_value = numpy.nan
self._fill_value = fill_value
self._axes_names = list(scatter_coordinate_names)
# Unit conversion of coordinates and resolution
if reference_units or reference_units_fallback:
scatter_coordinates, scatter_coordinate_units = (
units.convert_units_to_group_reference(
scatter_coordinates,
scatter_coordinate_units,
reference_units=reference_units,
fallback=reference_units_fallback,
)
)
else:
scatter_coordinate_units = units.normalize_units(scatter_coordinate_units)
assert len(scatter_coordinate_names) == len(scatter_coordinates)
assert len(scatter_coordinate_units) == len(scatter_coordinates)
resolution = units.convert_values_to_units(resolution, scatter_coordinate_units)
self._units = scatter_coordinate_units
# Trim coordinate length
nscatter = set()
for i, coordi in enumerate(scatter_coordinates):
if coordi.ndim != 1:
raise ValueError(f"Coordinate {i} must be provided as a 1D array")
if coordi.size == 0:
raise ValueError(f"Coordinate {i} is empty")
nscatter.add(len(coordi))
self._nscatter = min(nscatter)
if len(nscatter) > 1:
scatter_coordinates = [
coordi[: self._nscatter] for coordi in scatter_coordinates
]
# Shape (Nscatter, Ndim)
self._scatter_coordinates = numpy.vstack(scatter_coordinates).T
# Define interpolation grid
self._grid_axes = optimal_grid_axes(
self._scatter_coordinates,
fix_resolution=fix_resolution,
fix_limits=fix_limits,
resolution=resolution,
)
self._grid_shape = tuple(len(arr) for arr in self._grid_axes)
self._expanded_grid_coordinates = grid_utils.expanded_grid_coordinates(
self._grid_axes
)
# Scaled interpolation grid
scale, offset = distance.compute_axis_transform(
self._scatter_coordinates, method=scale_method
)
self._normalized_scatter_coordinates = distance.normalize_coordinates(
self._scatter_coordinates, scale, offset
)
self._normalized_grid_coordinates = distance.normalize_coordinates(
self._expanded_grid_coordinates, scale, offset
)
# Distances between closest scatter and grid coordinate
self._closest_scatter_index, self._closest_scatter_distance = (
grid_utils.closest_point(
self._normalized_grid_coordinates,
self._normalized_scatter_coordinates,
)
)
self._closest_grid_index, self._closest_grid_distance = (
grid_utils.closest_point(
self._normalized_scatter_coordinates,
self._normalized_grid_coordinates,
)
)
# Identify outliers
per_axis_resolution = numpy.array(
[
abs(axis[1] - axis[0]) if len(axis) > 1 else 0.0
for axis in self._grid_axes
]
)
per_axis_scaled_threshold = distance.normalize_coordinates(
per_axis_resolution * outside_resolution_fraction, scale, None
)
scaled_distance_max = numpy.sqrt(numpy.sum(per_axis_scaled_threshold**2))
self._grid_coordinates_outside = (
self._closest_scatter_distance > scaled_distance_max
)
@property
def scan_size(self) -> int:
return self._nscatter
@property
def grid_size(self) -> int:
return self._expanded_grid_coordinates.shape[0]
@property
def grid_shape(self) -> Tuple[int, ...]:
"""
:returns: tuple `(N0, N1, ...)` with length: `Ndim`
"""
return self._grid_shape
@property
def grid_ndim(self) -> int:
return len(self._grid_shape)
@property
def axes_names(self) -> List[str]:
return self._axes_names
@property
def units(self) -> List[str]:
return self._units
@property
def grid_axes(self) -> List[numpy.ndarray]:
"""
:returns: `Ndim` arrays with shapes `[(N0,), (N1,), ...]`
"""
return self._grid_axes
@property
def expanded_grid_coordinates(self) -> numpy.ndarray:
"""
:returns: shape `(N0*N1*..., Ndim)`
"""
return self._expanded_grid_coordinates
@property
def grid_coordinates_outside(self) -> numpy.ndarray:
"""Mask for grid coordinates with shape `(N0*N1*..., )` that are outside the scatter cloud.
:returns: shape `(N0*N1*...,)`
"""
return self._grid_coordinates_outside
@property
def scatter_coordinates(self) -> numpy.ndarray:
"""
:returns: shape `(Nscatter, Ndim)`
"""
return self._scatter_coordinates
[docs]
def regrid(self, scatter_data: numpy.ndarray) -> numpy.ndarray:
"""
:param scatter_data: flat list of data values `(Nscatter, M0, M1, ...)` with
`(M0, M1, ...)` the shape of one detector data value
:returns: regridded data values `(N0, N1, ..., M0, M1, ...)`
"""
if scatter_data.ndim == 0:
raise ValueError("Data cannot be a scalar")
if scatter_data.size == 0:
raise ValueError("Data is empty")
ndata = len(scatter_data)
if ndata > self._nscatter:
scatter_data = scatter_data[: self._nscatter]
elif ndata < self._nscatter:
scatter_data = pad.pad_array(scatter_data, self._nscatter)
if scatter_data.ndim == 1:
return self._interpolate_0d_detector(scatter_data)
return self._interpolate_nd_detector(scatter_data)
def _interpolate_0d_detector(self, scatter_data: numpy.ndarray) -> numpy.ndarray:
interpolated_data = _interpolate(
scatter_coordinates=self._normalized_scatter_coordinates,
scatter_data=scatter_data,
interp_coordinates=self._normalized_grid_coordinates,
interp_coordinates_outside=self._grid_coordinates_outside,
method=self._method,
fill_value=self._fill_value,
)
return interpolated_data.reshape(*self._grid_shape)
def _interpolate_nd_detector(self, scatter_data: numpy.ndarray) -> numpy.ndarray:
nscatter = scatter_data.shape[0]
data_shape = scatter_data.shape[1:]
scatter_data_flat_detector = scatter_data.reshape((nscatter, -1))
nscatter, data_size = scatter_data_flat_detector.shape
interpolated_data = numpy.array(
[
_interpolate(
scatter_coordinates=self._normalized_scatter_coordinates,
scatter_data=scatter_data_flat_detector[:, i],
interp_coordinates=self._normalized_grid_coordinates,
interp_coordinates_outside=self._grid_coordinates_outside,
method=self._method,
fill_value=self._fill_value,
)
for i in range(data_size)
]
).T
return interpolated_data.reshape(*self._grid_shape, *data_shape)
def _plot_coordinates(self) -> None:
import matplotlib.pyplot as plt
outside = self.grid_coordinates_outside
inside = ~outside
plt.plot(
self.scatter_coordinates[:, 1],
self.scatter_coordinates[:, 0],
".",
color="blue",
)
plt.plot(
self.expanded_grid_coordinates[inside, 1],
self.expanded_grid_coordinates[inside, 0],
"o",
mfc="none",
color="green",
)
plt.plot(
self.expanded_grid_coordinates[outside, 1],
self.expanded_grid_coordinates[outside, 0],
"o",
mfc="none",
color="red",
)
plt.xlabel(self.axes_names[1])
plt.ylabel(self.axes_names[0])
plt.axis("equal")
plt.show()
[docs]
def save_coordinates_as_nxdata(self, parent: hdf5.GroupType) -> None:
nxcollection = parent.create_group("coordinates")
nxcollection.attrs["NX_class"] = "NXcollection"
nxdata = nxcollection.create_group("scatter_coordinates")
nxdata.attrs["NX_class"] = "NXdata"
nxdata["title"] = "Scatter coordinate indices"
self._save_nxdata_scatter_axes(nxdata)
signal = numpy.arange(len(self.scatter_coordinates))
self._save_nxdata_signal(nxdata, "indices", signal, units=False)
nxdata = nxcollection.create_group("grid_coordinates")
nxdata.attrs["NX_class"] = "NXdata"
nxdata["title"] = "Grid nodes that have data"
self._save_nxdata_grid_axes(nxdata)
signal = ~self.grid_coordinates_outside
self._save_nxdata_signal(nxdata, "has_data", signal, units=False)
nxdata = nxcollection.create_group("coordinates")
nxdata.attrs["NX_class"] = "NXdata"
nxdata["title"] = "Grid nodes (0 or 1) and Scatter coordinates (2)"
self._save_nxdata_scatter_and_grid_axes(nxdata)
signal = numpy.concatenate(
(
numpy.full((len(self.scatter_coordinates),), 2),
~self.grid_coordinates_outside,
)
)
self._save_nxdata_signal(nxdata, "coordinate_type", signal, units=False)
nxdata = nxcollection.create_group("closest_grid_distance")
nxdata.attrs["NX_class"] = "NXdata"
nxdata["title"] = "Closest grid node distance for every scatter coordinate"
self._save_nxdata_scatter_axes(nxdata)
signal = self._closest_grid_distance
self._save_nxdata_signal(nxdata, "distance", signal, units=True)
nxdata = nxcollection.create_group("closest_scatter_distance")
nxdata.attrs["NX_class"] = "NXdata"
nxdata["title"] = "Closest scatter coordinate distance for every grid node"
self._save_nxdata_grid_axes(nxdata)
signal = self._closest_scatter_distance.copy()
signal[self._grid_coordinates_outside] = numpy.nan
self._save_nxdata_signal(nxdata, "distance", signal, units=True)
def _save_nxdata_signal(
self, nxdata: hdf5.GroupType, name: str, data: numpy.ndarray, units: bool = True
) -> None:
nxdata.attrs["signal"] = name
dset = nxdata.create_dataset(name=name, data=data)
if units:
self._save_nxdata_units(dset, name, None)
def _save_nxdata_scatter_axes(self, nxdata: hdf5.GroupType) -> None:
nxdata.attrs["axes"] = self.axes_names[::-1]
for i, (name, values) in enumerate(
zip(self.axes_names, self.scatter_coordinates.T)
):
dset = nxdata.create_dataset(name=name, data=values)
self._save_nxdata_units(dset, name, i)
def _save_nxdata_grid_axes(self, nxdata: hdf5.GroupType) -> None:
nxdata.attrs["axes"] = self.axes_names[::-1]
for i, (name, values) in enumerate(
zip(self.axes_names, self.expanded_grid_coordinates.T)
):
dset = nxdata.create_dataset(name=name, data=values)
self._save_nxdata_units(dset, name, i)
def _save_nxdata_scatter_and_grid_axes(self, nxdata: hdf5.GroupType) -> None:
nxdata.attrs["axes"] = self.axes_names[::-1]
for i, (name, scatter_values, grid_values) in enumerate(
zip(
self.axes_names,
self.scatter_coordinates.T,
self.expanded_grid_coordinates.T,
)
):
values = numpy.concatenate((scatter_values, grid_values))
dset = nxdata.create_dataset(name=name, data=values)
self._save_nxdata_units(dset, name, i)
def _save_nxdata_units(
self,
dset: hdf5.DatasetType,
name: str,
axis_index: Optional[int],
) -> None:
if axis_index is None:
return
unit = self._units[axis_index]
if unit is not None and unit != "dimensionless":
dset.attrs["units"] = unit
dset.attrs["long_name"] = f"{name} ({unit})"
def _interpolate(
scatter_coordinates: numpy.ndarray,
scatter_data: numpy.ndarray,
interp_coordinates: numpy.ndarray,
interp_coordinates_outside: numpy.ndarray,
method: str,
fill_value: float,
) -> numpy.ndarray:
"""
:param scatter_coordinates: scatter coordinates with shape `(Nscatter, Ndim)`
:param scatter_data: flat list of scatter data values `(Nscatter,)`
:param interp_coordinates: interpolate coordinates with shape `(Ninterp, Ndim)`
:param interp_coordinates_outside: boolean array with shape `(Ninterp,)`
:param method:
:param fill_value:
:returns: interpolated data values `(Ninterp,)`
"""
# Remove degenerate dimensions
coordinate_range = numpy.ptp(scatter_coordinates, axis=0)
eps = 1e3 * numpy.finfo(scatter_coordinates.dtype).eps
non_degenerate = coordinate_range > eps
scatter_coordinates = scatter_coordinates[:, non_degenerate]
interp_coordinates = interp_coordinates[:, non_degenerate]
if scatter_coordinates.shape[1] == 0:
# Fully degenerate case
interp_data = numpy.full(
len(interp_coordinates),
scatter_data[0],
dtype=scatter_data.dtype,
)
else:
interp_data = griddata(
scatter_coordinates,
scatter_data,
interp_coordinates,
method=method,
fill_value=fill_value,
)
if numpy.isnan(fill_value) and numpy.issubdtype(interp_data.dtype, numpy.integer):
fill_value = 0
interp_data[interp_coordinates_outside] = fill_value
return interp_data