import h5py
from typing import Sequence, Tuple, List, Optional, Union
import numpy
from scipy.interpolate import griddata
from . import grid_utils
from ..units import unit_registry
from .optimal_grid import optimal_grid_axes
class ScatterDataInterpolator:
    """Interpolates scatter data on a regular grid.

    The regular grid, and the mask of grid nodes that are too far away from any
    scatter point, are computed once in the constructor. :meth:`regrid` can then
    be called for every data signal measured at the scatter coordinates.
    """

    def __init__(
        self,
        scatter_coordinates: Sequence[numpy.ndarray],
        scatter_coordinate_names: Sequence[str],
        scatter_coordinate_units: Sequence[Optional[str]],
        method: Optional[str] = None,
        fill_value=None,
        fix_resolution: bool = True,
        fix_limits: bool = True,
        resolution: Optional[
            Union[Sequence[float], Sequence[Tuple[float, str]]]
        ] = None,
        outside_resolution_fraction: float = 2.1,
    ):
        """
        :param scatter_coordinates: independent variables with shape `(Ndim, Nscatter)`.
            When the dimensions have a different number of points, all of them are
            truncated to the shortest one.
        :param scatter_coordinate_names: independent variable names with shape `(Ndim,)`
        :param scatter_coordinate_units: independent variable units with shape `(Ndim,)`.
            When the units differ, all coordinates are converted to the units of the
            first dimension.
        :param method: interpolation method (default: `"nearest"`).
        :param fill_value: value used for extrapolation (default: `NaN`).
        :param fix_resolution: fix the resolution (forwarded to `optimal_grid_axes`).
        :param fix_limits: fix the limits (forwarded to `optimal_grid_axes`).
        :param resolution: grid spacing in each dimension. A `(value, units)` tuple is
            converted to the units of the first dimension, a plain number is used as is.
        :param outside_resolution_fraction: a grid point is considered to be outside the space spanned by the scatter points
            when the closest scatter point is farther away than a fraction of the smallest resolution
            across all dimensions.
        :raises ValueError: when the number of names or units differs from the number
            of coordinates, when the units differ and some of them are missing,
            or when a coordinate is not 1D or is empty.
        """
        # Validate the lengths before anything is zipped together: `zip` silently
        # truncates to the shortest sequence, which would hide a mismatch.
        ndim = len(scatter_coordinates)
        if len(scatter_coordinate_names) != ndim:
            raise ValueError(
                f"Expected {ndim} coordinate names, got {len(scatter_coordinate_names)}"
            )
        if len(scatter_coordinate_units) != ndim:
            raise ValueError(
                f"Expected {ndim} coordinate units, got {len(scatter_coordinate_units)}"
            )

        if method is None:
            method = "nearest"
        self._method = method
        if fill_value is None:
            fill_value = numpy.nan
        self._fill_value = fill_value
        self._axes_names = scatter_coordinate_names

        # Units
        self._units = scatter_coordinate_units[0]
        ureg = unit_registry()
        if len(set(scatter_coordinate_units)) > 1:
            # All dimensions will have the same units (self._units)
            if any(units is None for units in scatter_coordinate_units):
                raise ValueError(
                    "Coordinate units must be provided for all dimensions when they differ"
                )
            if self._units:
                scatter_coordinates = [
                    (v * ureg.parse_units(u)).to(self._units).magnitude
                    for v, u in zip(scatter_coordinates, scatter_coordinate_units)
                ]
        if resolution and self._units:
            # Resolution must have the same units as scatter_coordinates
            resolution_correct_units = []
            for value in resolution:
                if isinstance(value, tuple):
                    v, u = value
                    value = (v * ureg.parse_units(u)).to(self._units).magnitude
                resolution_correct_units.append(value)
            resolution = resolution_correct_units

        # Scatter coordinates
        nscatter = set()
        for i, coordi in enumerate(scatter_coordinates):
            if coordi.ndim != 1:
                raise ValueError(f"Coordinate {i} must be provided as a 1D array")
            if coordi.size == 0:
                raise ValueError(f"Coordinate {i} is empty")
            nscatter.add(len(coordi))
        # Dimensions with more points than the shortest one are truncated
        self._nscatter = min(nscatter)
        if len(nscatter) > 1:
            scatter_coordinates = [
                coordi[: self._nscatter] for coordi in scatter_coordinates
            ]
        self._scatter_coordinates = numpy.vstack(
            scatter_coordinates
        ).T  # (Nscatter, Ndim)

        # Grid coordinates
        self._grid_axes = optimal_grid_axes(
            self._scatter_coordinates,
            fix_resolution=fix_resolution,
            fix_limits=fix_limits,
            resolution=resolution,
        )  # [(N0,), (N1,), ...]
        self._grid_shape = tuple(len(arr) for arr in self._grid_axes)  # (N0, N1, ...)
        self._expanded_grid_coordinates = grid_utils.expanded_grid_coordinates(
            self._grid_axes
        )  # (N0*N1*..., Ndim)

        # Difference between scatter and grid coordinates
        self._closest_scatter_index, self._closest_scatter_distance = (
            grid_utils.closest_point(
                self._expanded_grid_coordinates, self._scatter_coordinates
            )
        )
        self._closest_grid_index, self._closest_grid_distance = (
            grid_utils.closest_point(
                self._scatter_coordinates, self._expanded_grid_coordinates
            )
        )

        # A grid node is "outside" when its closest scatter point is farther away
        # than a fraction of the smallest grid spacing across all dimensions.
        distance_max = min(
            abs(axis[1] - axis[0]) * outside_resolution_fraction
            for axis in self._grid_axes
        )
        self._grid_coordinates_outside = self._closest_scatter_distance > distance_max

    @property
    def scan_size(self) -> int:
        """
        :returns: number of scatter points `Nscatter`
        """
        return self._nscatter

    @property
    def grid_size(self) -> int:
        """
        :returns: number of grid nodes `N0*N1*...`
        """
        # The number of rows is the node count; indexing with `[0]` would return
        # the coordinates of the first grid node instead.
        return len(self._expanded_grid_coordinates)

    @property
    def grid_shape(self) -> Tuple[int, ...]:
        """
        :returns: tuple `(N0, N1, ...)` with length: `Ndim`
        """
        return self._grid_shape

    @property
    def grid_ndim(self) -> int:
        """
        :returns: number of grid dimensions `Ndim`
        """
        return len(self._grid_shape)

    @property
    def axes_names(self) -> Sequence[str]:
        """
        :returns: the coordinate names as provided to the constructor, length `Ndim`
        """
        return self._axes_names

    @property
    def units(self) -> Optional[str]:
        """
        :returns: units of the first dimension, shared by all dimensions
        """
        return self._units

    @property
    def grid_axes(self) -> List[numpy.ndarray]:
        """
        :returns: `Ndim` arrays with shapes `[(N0,), (N1,), ...]`
        """
        return self._grid_axes

    @property
    def expanded_grid_coordinates(self) -> numpy.ndarray:
        """
        :returns: shape `(N0*N1*..., Ndim)`
        """
        return self._expanded_grid_coordinates

    @property
    def grid_coordinates_outside(self) -> numpy.ndarray:
        """Mask for grid coordinates with shape `(N0*N1*..., )` that are outside the scatter cloud.

        :returns: shape `(N0*N1*...,)`
        """
        return self._grid_coordinates_outside

    @property
    def scatter_coordinates(self) -> numpy.ndarray:
        """
        :returns: shape `(Nscatter, Ndim)`
        """
        return self._scatter_coordinates

    def regrid(self, scatter_data: numpy.ndarray) -> numpy.ndarray:
        """Interpolate data measured at the scatter coordinates on the regular grid.

        Data with more than `Nscatter` values is truncated, data with less values
        is padded with missing values (`NaN`, or `0` for integer data).

        :param scatter_data: flat list of data values `(Nscatter, M0, M1, ...)` with
            `(M0, M1, ...)` the shape of one detector data value
        :returns: regridded data values `(N0, N1, ..., M0, M1, ...)`
        :raises ValueError: when the data is a scalar or is empty.
        """
        if scatter_data.ndim == 0:
            raise ValueError("Data cannot be a scalar")
        if scatter_data.size == 0:
            raise ValueError("Data is empty")
        if len(scatter_data) > self._nscatter:
            scatter_data = scatter_data[: self._nscatter]
        elif len(scatter_data) < self._nscatter:
            nextra = self._nscatter - len(scatter_data)
            # Integers cannot represent NaN (the cast yields an arbitrary value),
            # so integer data is padded with 0. This is the same substitution
            # `_interpolate` applies when filling integer results.
            if numpy.issubdtype(scatter_data.dtype, numpy.integer):
                pad_value = 0
            else:
                pad_value = numpy.nan
            scatter_data = numpy.pad(
                scatter_data,
                [(0, nextra)] + [(0, 0)] * (scatter_data.ndim - 1),
                constant_values=pad_value,
            )
        if scatter_data.ndim == 1:
            return self._interpolate_0d_detector(scatter_data)
        return self._interpolate_nd_detector(scatter_data)

    def _interpolate_0d_detector(self, scatter_data: numpy.ndarray) -> numpy.ndarray:
        """Regrid one scalar value per scatter point.

        :param scatter_data: shape `(Nscatter,)`
        :returns: shape `(N0, N1, ...)`
        """
        interpolated_data = _interpolate(
            scatter_coordinates=self._scatter_coordinates,
            scatter_data=scatter_data,
            interp_coordinates=self._expanded_grid_coordinates,
            interp_coordinates_outside=self._grid_coordinates_outside,
            method=self._method,
            fill_value=self._fill_value,
        )
        return interpolated_data.reshape(*self._grid_shape)

    def _interpolate_nd_detector(self, scatter_data: numpy.ndarray) -> numpy.ndarray:
        """Regrid one N-dimensional detector value per scatter point.

        Every detector element is interpolated independently.

        :param scatter_data: shape `(Nscatter, M0, M1, ...)`
        :returns: shape `(N0, N1, ..., M0, M1, ...)`
        """
        nscatter = scatter_data.shape[0]
        data_shape = scatter_data.shape[1:]
        # Flatten the detector dimensions: (Nscatter, M0*M1*...)
        scatter_data_flat_detector = scatter_data.reshape((nscatter, -1))
        data_size = scatter_data_flat_detector.shape[1]
        interpolated_data = numpy.array(
            [
                _interpolate(
                    scatter_coordinates=self._scatter_coordinates,
                    scatter_data=scatter_data_flat_detector[:, i],
                    interp_coordinates=self._expanded_grid_coordinates,
                    interp_coordinates_outside=self._grid_coordinates_outside,
                    method=self._method,
                    fill_value=self._fill_value,
                )
                for i in range(data_size)
            ]
        ).T  # (N0*N1*..., M0*M1*...)
        return interpolated_data.reshape(*self._grid_shape, *data_shape)

    def _plot_coordinates(self) -> None:
        """Show the scatter and grid coordinates of the first two dimensions.

        Scatter points are blue dots, grid nodes inside the scatter cloud are green
        circles and grid nodes outside the scatter cloud are red circles. Dimension 1
        is plotted horizontally and dimension 0 vertically.

        Not called by the class itself; presumably meant for interactive inspection.
        """
        import matplotlib.pyplot as plt

        outside = self.grid_coordinates_outside
        inside = ~outside
        plt.plot(
            self.scatter_coordinates[:, 1],
            self.scatter_coordinates[:, 0],
            ".",
            color="blue",
        )
        plt.plot(
            self.expanded_grid_coordinates[inside, 1],
            self.expanded_grid_coordinates[inside, 0],
            "o",
            mfc="none",
            color="green",
        )
        plt.plot(
            self.expanded_grid_coordinates[outside, 1],
            self.expanded_grid_coordinates[outside, 0],
            "o",
            mfc="none",
            color="red",
        )
        plt.xlabel(self.axes_names[1])
        plt.ylabel(self.axes_names[0])
        plt.axis("equal")
        plt.show()

    def save_coordinates_as_nxdata(self, parent: h5py.Group) -> None:
        """Save the scatter and grid coordinates as scatter plots.

        Creates a group "coordinates" (NXcollection) in `parent` which contains
        one NXdata group per plot.

        :param parent: HDF5 group in which the "coordinates" group is created.
        """
        nxcollection = parent.create_group("coordinates")
        nxcollection.attrs["NX_class"] = "NXcollection"

        # Index of every scatter coordinate
        nxdata = nxcollection.create_group("scatter_coordinates")
        nxdata.attrs["NX_class"] = "NXdata"
        nxdata["title"] = "Scatter coordinate indices"
        self._save_nxdata_scatter_axes(nxdata)
        signal = numpy.arange(len(self.scatter_coordinates))
        self._save_nxdata_signal(nxdata, "indices", signal, units=False)

        # Grid nodes: inside (True) or outside (False) the scatter cloud
        nxdata = nxcollection.create_group("grid_coordinates")
        nxdata.attrs["NX_class"] = "NXdata"
        nxdata["title"] = "Grid nodes that have data"
        self._save_nxdata_grid_axes(nxdata)
        signal = ~self.grid_coordinates_outside
        self._save_nxdata_signal(nxdata, "has_data", signal, units=False)

        # Scatter coordinates followed by the grid nodes in one plot
        nxdata = nxcollection.create_group("coordinates")
        nxdata.attrs["NX_class"] = "NXdata"
        nxdata["title"] = "Grid nodes (0 or 1) and Scatter coordinates (2)"
        self._save_nxdata_scatter_and_grid_axes(nxdata)
        signal = numpy.concatenate(
            (
                numpy.full((len(self.scatter_coordinates),), 2),
                ~self.grid_coordinates_outside,
            )
        )
        self._save_nxdata_signal(nxdata, "coordinate_type", signal, units=False)

        nxdata = nxcollection.create_group("closest_grid_distance")
        nxdata.attrs["NX_class"] = "NXdata"
        nxdata["title"] = "Closest grid node distance for every scatter coordinate"
        self._save_nxdata_scatter_axes(nxdata)
        signal = self._closest_grid_distance
        self._save_nxdata_signal(nxdata, "distance", signal, units=True)

        nxdata = nxcollection.create_group("closest_scatter_distance")
        nxdata.attrs["NX_class"] = "NXdata"
        nxdata["title"] = "Closest scatter coordinate distance for every grid node"
        self._save_nxdata_grid_axes(nxdata)
        # Copy before masking so the cached distances are not modified
        signal = self._closest_scatter_distance.copy()
        signal[self._grid_coordinates_outside] = numpy.nan
        self._save_nxdata_signal(nxdata, "distance", signal, units=True)

    def _save_nxdata_signal(
        self, nxdata: h5py.Group, name: str, data: numpy.ndarray, units: bool = True
    ) -> None:
        """Save `data` as dataset `name` and mark it as the signal of `nxdata`.

        :param units: when `True` the coordinate units are attached to the dataset.
        """
        nxdata.attrs["signal"] = name
        dset = nxdata.create_dataset(name=name, data=data)
        if units:
            self._save_nxdata_units(dset, name)

    def _save_nxdata_scatter_axes(self, nxdata: h5py.Group) -> None:
        """Save the scatter coordinates, one dataset per dimension, as axes of `nxdata`.

        The "axes" attribute lists the names in reverse dimension order.
        """
        nxdata.attrs["axes"] = self.axes_names[::-1]
        for name, values in zip(self.axes_names, self.scatter_coordinates.T):
            dset = nxdata.create_dataset(name=name, data=values)
            self._save_nxdata_units(dset, name)

    def _save_nxdata_grid_axes(self, nxdata: h5py.Group) -> None:
        """Save the expanded grid coordinates, one dataset per dimension, as axes of `nxdata`.

        The "axes" attribute lists the names in reverse dimension order.
        """
        nxdata.attrs["axes"] = self.axes_names[::-1]
        for name, values in zip(self.axes_names, self.expanded_grid_coordinates.T):
            dset = nxdata.create_dataset(name=name, data=values)
            self._save_nxdata_units(dset, name)

    def _save_nxdata_scatter_and_grid_axes(self, nxdata: h5py.Group) -> None:
        """Save the scatter coordinates followed by the expanded grid coordinates,
        one dataset per dimension, as axes of `nxdata`.

        The "axes" attribute lists the names in reverse dimension order.
        """
        nxdata.attrs["axes"] = self.axes_names[::-1]
        for name, scatter_values, grid_values in zip(
            self.axes_names,
            self.scatter_coordinates.T,
            self.expanded_grid_coordinates.T,
        ):
            values = numpy.concatenate((scatter_values, grid_values))
            dset = nxdata.create_dataset(name=name, data=values)
            self._save_nxdata_units(dset, name)

    def _save_nxdata_units(self, dset: h5py.Dataset, name: str) -> None:
        """Attach the coordinate units to `dset`. Nothing is done when there are no units."""
        if self.units is not None:
            dset.attrs["units"] = self.units
            dset.attrs["long_name"] = f"{name} ({self.units})"


def _interpolate(
scatter_coordinates: numpy.ndarray,
scatter_data: numpy.ndarray,
interp_coordinates: numpy.ndarray,
interp_coordinates_outside: numpy.ndarray,
method: str,
fill_value: float,
) -> numpy.ndarray:
"""
:param scatter_coordinates: scatter coordinates with shape `(Nscatter, Ndim)`
:param scatter_data: flat list of scatter data values `(Nscatter,)`
:param interp_coordinates: interpolate coordinates with shape `(Ninterp, Ndim)`
:param interp_coordinates_outside: boolean array with shape `(Ninterp,)`
:param method:
:param fill_value:
:returns: interpolated data values `(Ninterp,)`
"""
interp_data = griddata(
scatter_coordinates,
scatter_data,
interp_coordinates,
method=method,
fill_value=fill_value,
)
if numpy.isnan(fill_value) and numpy.issubdtype(interp_data.dtype, numpy.integer):
fill_value = 0
interp_data[interp_coordinates_outside] = fill_value
return interp_data