# Source code for ewoksfluo.tasks.regrid.regrid_stack

from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Union

from ewokscore import Task
from ewokscore.model import BaseInputModel
from ewokscore.model import BaseOutputModel
from pydantic import Field

from ...io import hdf5
from ...io import output_uri
from ...math.regular_grid import ScatterDataInterpolator
from ...math.rounding import round_to_significant
from .. import nexus_utils
from ..positioner_utils import get_primary_beam_energy_suburi
from ..xrf_results import get_xrf_result_groups
from . import positioners_utils


class Inputs(BaseInputModel):
    """Input parameters of the XRF results stack regridding task."""

    bliss_scan_uris: List[str] = Field(
        description="Bliss scan URI's.",
        examples=[["/data/dataset.h5::/1.1", "/data/dataset.h5::/2.1"]],
    )
    output_root_uri: str = Field(
        description="Target HDF5 file URI with optional data path.",
        examples=[
            "/results/dataset.h5",
            "/results/dataset.h5::/1.1",
            "/results/dataset.h5::/1.1/regrid",
        ],
    )
    output_root_group: Optional[str] = Field(
        default=None, description="Optional group underneath ``output_root_uri``."
    )
    xrf_results_uri: str = Field(
        description="Previous XRF NXprocess URI to regrid.",
        examples=["/results/dataset.h5::/1.1/regrid/results"],
    )
    stack_positioner: Optional[str] = Field(
        default=None,
        description=(
            "Positioner name of the stack axis, formatted with "
            "``positioner_uri_template``. When not provided, the primary beam "
            "energy positioner of the first scan is used."
        ),
    )
    positioners: Union[List[str], str, None] = Field(
        default=None,
        description=(
            "Positioner name(s) of the regular grid axes, each formatted with "
            "``positioner_uri_template``. When not provided, the positioners "
            "are derived from the first scan."
        ),
    )
    positioner_uri_template: str = Field(
        default="measurement/{}",
        description=(
            "Scan-relative URI template in which ``{}`` is replaced by a "
            "positioner name."
        ),
    )
    ignore_positioners: Optional[List[str]] = Field(
        default=None,
        description=(
            "Positioners to exclude when the grid positioners are derived from "
            "the first scan (only used when ``positioners`` is not provided)."
        ),
    )
    interpolate: Literal["nearest", "linear", "cubic"] = Field(
        default="nearest",
        description="Interpolation method used to regrid each scan.",
    )
    resolution: Optional[dict] = Field(
        default=None,
        description=(
            "Optional grid resolution keyed by grid axis name; every grid axis "
            "must have an entry. When empty or not provided, no resolution is "
            "passed to the interpolator."
        ),
    )
    axes_units: Dict[str, str] = Field(
        default_factory=dict,
        description="Axes units to be used when missing.",
        examples=[{"nsy": "mm", "nsz": "mm", "nspy": "um", "nspz": "um"}],
    )
class Outputs(BaseOutputModel):
    """Output values of the XRF results stack regridding task."""

    # Scans that were regridded (passed through unchanged from the inputs).
    bliss_scan_uris: List[str] = Field(
        examples=[["/data/dataset.h5::/1.1", "/data/dataset.h5::/2.1"]],
        description="Bliss scan URI's.",
    )
    # Passed through unchanged from the inputs.
    output_root_uri: str = Field(
        examples=["/results/dataset.h5::/1.1"],
        description="Original output root URI received as input.",
    )
    # Passed through unchanged from the inputs.
    output_root_group: Optional[str] = Field(
        None,
        description="Original output root group received as input.",
    )
    # URI of the NXprocess group written by this task.
    xrf_results_uri: str = Field(
        examples=["/results/dataset.h5::/1.1/regrid/results"],
        description="Output NXprocess results URI.",
    )
class RegridXrfResultsStack(Task, input_model=Inputs, output_model=Outputs):
    """Regrid XRF stack results on a regular grid by interpolation."""

    def run(self):
        """Regrid the XRF results unless the output group already exists.

        The output URI is composed from ``output_root_uri`` (defaulting to the
        data path of the first Bliss scan) extended with ``output_root_group``
        and ``"regrid"``.
        """
        start_time = nexus_utils.now()
        _, scan_h5path0 = hdf5.split_h5uri(self.inputs.bliss_scan_uris[0])
        output_root_uri = output_uri.compose_full_output_uri(
            self.inputs.output_root_uri,
            default_output_data_path=scan_h5path0,
            extra_data_paths=(self.inputs.output_root_group, "regrid"),
        )

        with nexus_utils.save_in_ewoks_subprocess(output_root_uri, start_time, {}) as (
            regrid_results,
            already_existed,
        ):
            if not already_existed:
                self._regrid(regrid_results)

            self.outputs.xrf_results_uri = (
                f"{regrid_results.file.filename}::{regrid_results.name}"
            )

        self.outputs.bliss_scan_uris = self.inputs.bliss_scan_uris
        self.outputs.output_root_uri = self.inputs.output_root_uri
        self.outputs.output_root_group = self.inputs.output_root_group

    def _regrid(self, regrid_results):
        """Interpolate each scan of the stack on a regular grid and save the stack.

        For every NXdata group returned by ``get_xrf_result_groups``, one output
        NXdata group is created in ``regrid_results`` whose signals have shape
        ``(nscans,) + regridded_shape``, with the stack positioner as first axis
        followed by the grid axes.

        :param regrid_results: HDF5 group in which the results are written.
        :raises TypeError: when ``xrf_results_uri`` does not point to a group.
        :raises ValueError: when the stack positioner units differ between scans.
        """
        bliss_scan_uris = self.inputs.bliss_scan_uris
        position_suburis = self._get_position_suburis(
            bliss_scan_uris, ignore_positioners=self.inputs.ignore_positioners
        )
        stack_suburi = self._get_stack_suburi(bliss_scan_uris)
        resolution = self.inputs.resolution

        xrf_results_filename, xrf_results_h5path = hdf5.split_h5uri(
            self.inputs.xrf_results_uri
        )
        with hdf5.FileReadAccess(xrf_results_filename) as xrf_results_file:
            xrf_results_grp = xrf_results_file[xrf_results_h5path]
            if not hdf5.is_group(xrf_results_grp):
                raise TypeError(f"'{xrf_results_h5path}' must be a group")
            nxdata_groups = get_xrf_result_groups(xrf_results_grp)

            # The grid is defined by the positions of the first scan.
            coordinates, grid_names, grid_units = (
                positioners_utils.read_position_suburis(
                    bliss_scan_uris[0],
                    position_suburis,
                    axes_units=self.inputs.axes_units,
                )
            )
            if resolution:
                # One resolution value per grid axis, in grid axis order.
                resolution = [resolution[name] for name in grid_names]
            else:
                resolution = None
            interpolator = ScatterDataInterpolator(
                coordinates,
                grid_names,
                grid_units,
                method=self.inputs.interpolate,
                resolution=resolution,
            )

            axes_names = [stack_suburi.split("/")[-1], *grid_names]

            # Get fit result datasets (inputs) and output dataset information
            input_datasets = []
            output_info = []
            output_grps = []
            for group_name in reversed(list(nxdata_groups)):
                input_grp = xrf_results_grp[group_name]
                output_grp = nexus_utils.create_nxdata(regrid_results, group_name)
                output_grps.append(output_grp)
                signals = []
                for dset_name, dset in input_grp.items():
                    # Axis datasets of the input are not regridded as signals.
                    if not hdf5.is_dataset(dset) or dset_name in axes_names:
                        continue
                    input_datasets.append(dset)
                    output_info.append((output_grp, group_name, dset_name))
                    signals.append(dset_name)
                nexus_utils.set_nxdata_signals(output_grp, signals=signals)

            # NXdata signals: iterating the input datasets yields one slice per scan.
            nscans = len(bliss_scan_uris)
            output_datasets = {}
            stack_axis_data = []
            for scan_index, (bliss_scan_uri, *input_data) in enumerate(
                zip(bliss_scan_uris, *input_datasets)
            ):
                stack_axis_data.append(
                    positioners_utils.get_position_data(bliss_scan_uri, stack_suburi)
                )
                for (output_grp, group_name, dset_name), data in zip(
                    output_info, input_data
                ):
                    data = interpolator.regrid(data)
                    key = group_name, dset_name
                    dset = output_datasets.get(key)
                    if dset is None:
                        # Allocate the stacked dataset lazily: the regridded
                        # shape and dtype are only known after the first scan.
                        stack_shape = (nscans,) + data.shape
                        dset = output_grp.create_dataset(
                            dset_name, shape=stack_shape, dtype=data.dtype
                        )
                        output_datasets[key] = dset
                    dset[scan_index] = data

            stack_axis_values, stack_axis_units = zip(*stack_axis_data)
            unique_units = set(stack_axis_units)
            if len(unique_units) > 1:
                # Stacking values expressed in different units would give a
                # meaningless axis, so refuse instead of picking one at random.
                raise ValueError(
                    f"Stack positioner '{stack_suburi}' has different units in "
                    f"different scans: {sorted(map(str, unique_units))}"
                )
            stack_axis_unit = stack_axis_units[0]

            # NXdata axes
            axes_data = [stack_axis_values] + interpolator.grid_axes
            all_axes_units = [stack_axis_unit] + interpolator.units
            title_parts = []
            for iaxis, (axisname, axisunits, arr) in enumerate(
                zip(axes_names, all_axes_units, axes_data)
            ):
                title_parts.append(self._axis_title(axisname, arr, axisunits))
                for output_grp in output_grps:
                    dset = output_grp.create_dataset(axisname, data=arr)
                    if axisunits:
                        dset.attrs["units"] = axisunits
                        dset.attrs["long_name"] = f"{axisname} ({axisunits})"
                    # NOTE(review): the NeXus NXdata definition describes
                    # AXISNAME_indices as a group attribute, not a dataset —
                    # confirm whether readers rely on this dataset before changing.
                    output_grp.create_dataset(f"{axisname}_indices", data=iaxis)
            for output_grp in output_grps:
                output_grp.attrs["axes"] = axes_names

            interpolator.save_coordinates_as_nxdata(regrid_results)
            regrid_results["title"] = " x ".join(title_parts) or "No axes"

    @staticmethod
    def _axis_title(axisname: str, arr, units) -> str:
        """Format one axis for the NXprocess title as ``name (size x step units)``.

        The step is the absolute spacing between the first two axis values.
        An axis with fewer than two values has no spacing (e.g. a stack of a
        single scan) and only reports its size.
        """
        size = len(arr)
        if size < 2:
            return f"{axisname} ({size})"
        step = round_to_significant(abs(arr[1] - arr[0]))
        return f"{axisname} ({size} x {step} {units})"

    def _get_position_suburis(
        self,
        bliss_scan_uris: Sequence[str],
        ignore_positioners: Optional[Sequence[str]] = None,
    ) -> List[str]:
        """Return the scan-relative URIs of the grid positioners.

        Explicit ``positioners`` are formatted with ``positioner_uri_template``.
        Otherwise they are derived from the first scan, skipping
        ``ignore_positioners``.
        """
        positioners = self.inputs.positioners
        if not positioners:
            return positioners_utils.get_scan_position_suburis(
                bliss_scan_uris[0], ignore_positioners=ignore_positioners
            )
        if isinstance(positioners, str):
            positioners = [positioners]
        template = self.inputs.positioner_uri_template
        return [template.format(s) for s in positioners]

    def _get_stack_suburi(self, bliss_scan_uris: Sequence[str]) -> str:
        """Return the scan-relative URI of the stack positioner.

        Uses ``stack_positioner`` formatted with ``positioner_uri_template``
        when provided, otherwise the primary beam energy of the first scan.

        :raises RuntimeError: when no energy positioner can be found.
        """
        stack_positioner = self.inputs.stack_positioner
        if stack_positioner:
            template = self.inputs.positioner_uri_template
            return template.format(stack_positioner)
        suburi = get_primary_beam_energy_suburi(bliss_scan_uris[0])
        if not suburi:
            raise RuntimeError("Cannot find energy positioner")
        return suburi