"""Regrid single-scan XRF results on a regular grid (``ewoksfluo.tasks.regrid.regrid``)."""

from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Union

from ewokscore import Task
from ewokscore.model import BaseInputModel
from ewokscore.model import BaseOutputModel
from pydantic import Field

from ...io import hdf5
from ...io import output_uri
from ...math.regular_grid import ScatterDataInterpolator
from ...math.rounding import round_to_significant
from .. import nexus_utils
from ..xrf_results import get_xrf_result_groups
from . import positioners_utils


class Inputs(BaseInputModel):
    # Task inputs of RegridXrfResults. Field order, types and defaults are part
    # of the task interface; descriptions document how the task uses each field.
    bliss_scan_uri: str = Field(
        description="Bliss scan URI.", examples=["/data/dataset.h5::/1.1"]
    )
    output_root_uri: str = Field(
        description="Target HDF5 file URI with optional data path.",
        examples=[
            "/results/dataset.h5",
            "/results/dataset.h5::/1.1",
            "/results/dataset.h5::/1.1/regrid",
        ],
    )
    output_root_group: Optional[str] = Field(
        default=None, description="Optional group underneath ``output_root_uri``."
    )
    xrf_results_uri: str = Field(
        description="Previous XRF NXprocess URI to regrid.",
        examples=["/results/dataset.h5::/1.1/norm/results"],
    )
    positioners: Union[List[str], str, None] = Field(
        default=None,
        description=(
            "Name(s) of the positioners defining the grid axes. "
            "When not provided, the positioners are derived from the Bliss scan."
        ),
        examples=[["nsy", "nsz"]],
    )
    positioner_uri_template: str = Field(
        default="measurement/{}",
        description=(
            "Template formatted with each name in ``positioners`` to obtain "
            "the position sub-URI within the Bliss scan."
        ),
    )
    ignore_positioners: Optional[List[str]] = Field(
        default=None,
        description=(
            "Positioners to exclude. Only used when ``positioners`` is not provided."
        ),
    )
    interpolate: Literal["none", "nearest", "linear", "cubic"] = Field(
        default="nearest",
        description="Interpolation method used to regrid the XRF results.",
    )
    resolution: Optional[dict] = Field(
        default=None,
        description=(
            "Grid resolution per axis name. When provided, it must contain "
            "an entry for every grid axis."
        ),
    )
    axes_units: Dict[str, str] = Field(
        default_factory=dict,
        description="Axes units to be used when missing.",
        examples=[{"nsy": "mm", "nsz": "mm", "nspy": "um", "nspz": "um"}],
    )
class Outputs(BaseOutputModel):
    # Results published by RegridXrfResults (its ``output_model``).

    # Echo of the corresponding task input.
    bliss_scan_uri: str = Field(
        description="Bliss scan URI.",
        examples=["/data/dataset.h5::/1.1"],
    )

    # Echo of the corresponding task input (not the composed output URI).
    output_root_uri: str = Field(
        description="Original output root URI received as input.",
        examples=["/results/dataset.h5::/1.1"],
    )

    # Echo of the corresponding task input; may be absent.
    output_root_group: Optional[str] = Field(
        None,
        description="Original output root group received as input.",
    )

    # Location of the group that holds the regridded NXdata groups.
    xrf_results_uri: str = Field(
        description="Output NXprocess results URI.",
        examples=["/results/dataset.h5::/1.1/regrid/results"],
    )
class RegridXrfResults(Task, input_model=Inputs, output_model=Outputs):
    """Regrid single-scan XRF results on a regular grid by interpolation."""

    def run(self):
        """Regrid the XRF results and publish the output URIs.

        The output location is composed from ``output_root_uri`` (falling back
        to the scan data path), ``output_root_group`` and ``"regrid"``.
        Regridding is skipped when the output already existed; the output URIs
        are published in both cases.
        """
        start_time = nexus_utils.now()
        _, scan_h5path = hdf5.split_h5uri(self.inputs.bliss_scan_uri)

        output_root_uri = output_uri.compose_full_output_uri(
            self.inputs.output_root_uri,
            default_output_data_path=scan_h5path,
            extra_data_paths=(self.inputs.output_root_group, "regrid"),
        )

        with nexus_utils.save_in_ewoks_subprocess(output_root_uri, start_time, {}) as (
            regrid_results,
            already_existed,
        ):
            if not already_existed:
                self._regrid(regrid_results)
            self.outputs.xrf_results_uri = (
                f"{regrid_results.file.filename}::{regrid_results.name}"
            )

        self.outputs.bliss_scan_uri = self.inputs.bliss_scan_uri
        self.outputs.output_root_uri = self.inputs.output_root_uri
        self.outputs.output_root_group = self.inputs.output_root_group

    def _regrid(self, regrid_results) -> None:
        """Interpolate every NXdata group of the input XRF results on a regular grid.

        :param regrid_results: output HDF5 group receiving one NXdata group per
            regridded input group, the grid coordinates and a ``title`` dataset.
        :raises TypeError: when ``xrf_results_uri`` does not point to a group.
        :raises KeyError: when ``resolution`` is given but misses a grid axis.
        """
        position_suburis = self._get_position_suburis(
            self.inputs.bliss_scan_uri,
            ignore_positioners=self.inputs.ignore_positioners,
        )
        coordinates, grid_names, grid_units = positioners_utils.read_position_suburis(
            self.inputs.bliss_scan_uri,
            position_suburis,
            axes_units=self.inputs.axes_units,
        )

        interpolator = ScatterDataInterpolator(
            coordinates,
            grid_names,
            grid_units,
            method=self.inputs.interpolate,
            resolution=self._get_grid_resolution(grid_names),
        )

        xrf_results_filename, xrf_results_h5path = hdf5.split_h5uri(
            self.inputs.xrf_results_uri
        )
        with hdf5.FileReadAccess(xrf_results_filename) as xrf_results_file:
            xrf_results_grp = xrf_results_file[xrf_results_h5path]
            if not hdf5.is_group(xrf_results_grp):
                raise TypeError(f"'{self.inputs.xrf_results_uri}' must be a group")

            nxdata_groups = get_xrf_result_groups(xrf_results_grp)
            # NOTE(review): groups are processed in reverse order; presumably this
            # controls the creation order of the output NXdata groups — confirm.
            for group_name in reversed(list(nxdata_groups)):
                self._regrid_nxdata_group(
                    xrf_results_grp[group_name],
                    regrid_results,
                    group_name,
                    grid_names,
                    interpolator,
                )

        interpolator.save_coordinates_as_nxdata(regrid_results)
        regrid_results["title"] = self._make_title(grid_names, interpolator)

    def _get_grid_resolution(self, grid_names: Sequence[str]) -> Optional[List]:
        """Return the ``resolution`` input ordered like ``grid_names``.

        Returns ``None`` when no (or an empty) resolution is provided.

        :raises KeyError: when a grid axis has no resolution entry.
        """
        resolution = self.inputs.resolution
        if not resolution:
            return None
        missing = [name for name in grid_names if name not in resolution]
        if missing:
            raise KeyError(
                f"'resolution' has no entry for grid axes {missing} "
                f"(provided: {list(resolution)})"
            )
        return [resolution[name] for name in grid_names]

    @staticmethod
    def _regrid_nxdata_group(
        input_grp, regrid_results, group_name, grid_names, interpolator
    ) -> None:
        """Regrid the signal datasets of one input NXdata group into a new NXdata group.

        Input groups without any signal dataset are skipped (nothing is created).
        """
        # Datasets named like a grid axis are positions, not signals.
        input_datasets = {
            dset_name: dset
            for dset_name, dset in input_grp.items()
            if hdf5.is_dataset(dset) and dset_name not in grid_names
        }
        if not input_datasets:
            # NXdata group which does not plot scan data
            return

        # NXdata signals
        output_grp = nexus_utils.create_nxdata(regrid_results, group_name)
        for dset_name, dset in input_datasets.items():
            output_grp.create_dataset(dset_name, data=interpolator.regrid(dset[()]))
        nexus_utils.set_nxdata_signals(
            output_grp, signals=tuple(input_datasets.keys())
        )

        # NXdata axes: one 1D dataset per grid axis, in signal dimension order.
        axes = []
        for i, (axisname, axisunits, arr) in enumerate(
            zip(grid_names, interpolator.units, interpolator.grid_axes)
        ):
            axes.append(axisname)
            dset = output_grp.create_dataset(axisname, data=arr)
            if axisunits:
                dset.attrs["units"] = axisunits
                dset.attrs["long_name"] = f"{axisname} ({axisunits})"
            # NeXus defines AXISNAME_indices as an attribute of the NXdata group,
            # not as a dataset (a dataset would also be mistaken for a signal
            # when this output is regridded again).
            output_grp.attrs[f"{axisname}_indices"] = i
        output_grp.attrs["axes"] = axes

    @staticmethod
    def _make_title(grid_names, interpolator) -> str:
        """Describe the regular grid as ``"<axis> (<size> x <step> <units>)"`` parts.

        Parts are joined with ``" x "``; ``"No axes"`` is returned when there are
        no axes. The step is the rounded spacing of the first two grid points
        (``0.0`` for single-point axes); units are omitted when not available.
        """
        parts = []
        for axisname, axisunits, arr in zip(
            grid_names, interpolator.units, interpolator.grid_axes
        ):
            size = len(arr)
            step_size = round_to_significant(abs(arr[1] - arr[0])) if size > 1 else 0.0
            # Avoid rendering "None" or a trailing space for axes without units.
            units_suffix = f" {axisunits}" if axisunits else ""
            parts.append(f"{axisname} ({size} x {step_size}{units_suffix})")
        return " x ".join(parts) or "No axes"

    def _get_position_suburis(
        self, bliss_scan_uri: str, ignore_positioners: Optional[Sequence[str]] = None
    ) -> List[str]:
        """Return the position sub-URIs defining the grid axes.

        With explicit ``positioners`` (a name or a list of names), each name is
        formatted into ``positioner_uri_template`` and ``ignore_positioners`` is
        not used. Otherwise the sub-URIs come from
        ``positioners_utils.get_scan_position_suburis``, which receives
        ``ignore_positioners``.
        """
        positioners = self.inputs.positioners
        if not positioners:
            return positioners_utils.get_scan_position_suburis(
                bliss_scan_uri, ignore_positioners=ignore_positioners
            )
        if isinstance(positioners, str):
            positioners = [positioners]
        template = self.inputs.positioner_uri_template
        return [template.format(s) for s in positioners]