Source code for ewoksfluo.tasks.regrid_data.regrid_stack

from typing import Sequence, List

import h5py
from ewokscore import Task

from . import utils
from .. import nexus
from ..utils import get_nxdata_groups
from ..positioner_utils import get_energy_suburi


[docs] class RegridStackOnCoordinates( Task, input_names=["scan_uris", "fit_results_uri", "output_uri"], optional_input_names=[ "stack_positioner", "positioners", "positioner_uri_template", "interpolate", ], output_names=["regridded_uri"], ):
[docs] def run(self): start_time = nexus.now() scan_uris: Sequence[str] = self.inputs.scan_uris fit_results_uri: str = self.inputs.fit_results_uri output_uri: str = self.inputs.output_uri position_suburis = self._get_position_suburis(scan_uris) stack_suburi = self._get_stack_suburi(scan_uris, position_suburis) interpolate: str = self.get_input_value("interpolate", "linear") grid = utils.ScanGrid(method=interpolate) fit_results_filename, fit_results_h5path = nexus.split_h5uri(fit_results_uri) with nexus.save_in_ewoks_subprocess( output_uri, start_time, {}, default_levels=("results", "regrid") ) as regrid_results: with h5py.File(fit_results_filename, "r") as fit_results_file: fit_results_grp = fit_results_file[fit_results_h5path] if not isinstance(fit_results_grp, h5py.Group): raise TypeError(f"'{fit_results_h5path}' must be a group") nxdata_groups = get_nxdata_groups(fit_results_grp) # Get fit result datasets (inputs) and output dataset information input_datasets = list() output_info = list() output_grps = list() for group_name in reversed(list(nxdata_groups)): input_grp = fit_results_grp[group_name] output_grp = nexus.create_data(regrid_results, group_name) output_grps.append(output_grp) signals = list() for dset_name, dset in input_grp.items(): if not isinstance(dset, h5py.Dataset): continue key = group_name, dset_name input_datasets.append(dset) output_info.append((output_grp, group_name, dset_name)) signals.append(dset_name) nexus.set_data_signals(output_grp, signals=signals) # NXdata signals nscans = len(scan_uris) output_datasets = dict() stack_axis = list() for scan_index, (scan_uri, *input_data) in enumerate( zip(scan_uris, *input_datasets) ): stack_axis.append(utils.get_position_data(scan_uri, stack_suburi)) interpolator = grid.get_interpolator(scan_uri, position_suburis) for (output_grp, group_name, dset_name), data in zip( output_info, input_data ): data = interpolator(data) key = group_name, dset_name dset = output_datasets.get(key) if dset is None: stack_shape = (nscans,) + data.shape dset = output_grp.create_dataset( dset_name, shape=stack_shape, dtype=data.dtype ) output_datasets[key] = dset dset[scan_index] = data # NXdata axes axes_data = [stack_axis] + grid.mesh_coordinates axes_names = ( stack_suburi.split("/")[-1], ) + grid.names # Order: slow to fast axis for output_grp in output_grps: for i, (axisname, arr) in enumerate(zip(axes_names, axes_data)): output_grp.create_dataset(axisname, data=arr) output_grp.create_dataset(f"{axisname}_indices", data=i) output_grp.attrs["axes"] = axes_names # Order: not relevant self.outputs.regridded_uri = f"{output_uri}/results"
def _get_position_suburis(self, scan_uris: Sequence[str]) -> List[str]: positioners = self.get_input_value("positioners", None) if positioners: if isinstance(positioners, str): positioners = [positioners] template = self._get_positioner_uri_template() return [template.format(s) for s in positioners] return utils.get_scan_position_suburis(scan_uris[0]) def _get_stack_suburi( self, scan_uris: Sequence[str], position_suburis: List[str] ) -> str: stack_positioner = self.get_input_value("stack_positioner", None) if stack_positioner: template = self._get_positioner_uri_template() return template.format(stack_positioner) return get_energy_suburi(scan_uris[0]) def _get_positioner_uri_template(self) -> str: return self.get_input_value("positioner_uri_template", "measurement/{}")