# Source code for ewoksfluo.tasks.example_data.scan_data

import os
import functools
from typing import Iterator, Tuple, Optional, List, Sequence, Dict, NamedTuple

import h5py
import numpy
from silx.io import h5py_utils
from PyMca5.PyMcaIO.ConfigDict import ConfigDict

try:
    from imageio.v3 import imread
except ImportError:
    try:
        from imageio.v2 import imread
    except ImportError:
        from imageio import imread

from scipy import ndimage

from . import xrf_spectra
from .monitor import monitor_signal
from .deadtime import apply_dualchannel_signal_processing
from .. import nexus_utils


def save_2d_xrf_scans(
    filename: str,
    emission_line_groups: List[str],
    first_scan_number: int,
    shape: Tuple[int, int],
    mosaic: Tuple[int, int],
    energy: float = 12,
    flux: float = 1e7,
    expo_time: float = 0.1,
    counting_noise: bool = True,
    integral_type: bool = True,
    rois: Sequence = tuple(),
    nmcas: int = 1,
    max_deviation: float = 0,
    seed: Optional[int] = None,
) -> List[int]:
    """Simulates and saves 2D XRF scans to a NeXus file.

    :param filename: Output file path.
    :param emission_line_groups: List of emission line groups as "Element-Group".
    :param first_scan_number: Scan number of the first scan in the output file.
    :param shape: Scan shape (fast, slow).
    :param mosaic: Split shape in this number of blocks (fast, slow).
    :param energy: Incident X-ray energy in keV.
    :param flux: Incident X-ray flux in photons per second.
    :param expo_time: Exposure time per point in seconds.
    :param counting_noise: If True, adds Poisson noise to data.
    :param integral_type: If True, the measured data is requested as ``numpy.uint32``
                          (passed as ``integral_type`` to the signal processing),
                          otherwise ``None`` is passed.
    :param rois: Regions of interest as MCA channel index ranges. Each item is
                 unpacked into ``slice(*roi)`` along the channel axis, e.g. ``(start, stop)``.
    :param nmcas: Number of MCA detectors to save.
    :param max_deviation: Maximum deviation from a perfect grid as a fraction of the scan step size.
                          Negative values are treated as 0.
    :param seed: Random seed for reproducibility.
    :returns: Scan numbers as many as there are blocks in the mosaic.
    :raises ValueError: If `shape` or `mosaic` is not 2D.
    """
    # Validate before doing any work so the documented ValueError is honoured
    # (previously a 3D shape was silently truncated and a 1D shape gave IndexError).
    if len(shape) != 2:
        raise ValueError(f"'shape' must be 2D, got {shape!r}")
    if len(mosaic) != 2:
        raise ValueError(f"'mosaic' must be 2D, got {mosaic!r}")

    max_deviation = max(max_deviation, 0)
    rstate = numpy.random.RandomState(seed=seed)
    I0_max = int(flux * expo_time)
    # Reference monitor value used for all I0 normalizations.
    I0_reference = I0_max
    # Resolve the output dtype once instead of rebinding the boolean argument
    # inside the scan loop.
    measured_dtype = numpy.uint32 if integral_type else None
    # "Element-Group" -> [element, group]
    element_groups = [s.split("-") for s in emission_line_groups]

    specs = _amesh_specs(
        shape=shape, mosaic=mosaic, max_deviation=max_deviation, expo_time=expo_time
    )

    scan_numbers = list()
    for scan_number, amesh in enumerate(specs, start=first_scan_number):
        scan_numbers.append(scan_number)
        coordinates = _amesh_positions(amesh, rstate, max_deviation)

        # Decaying I0 image
        I0 = (I0_max * monitor_signal(expo_time, amesh.size)).reshape(
            amesh.shape, order="F"
        )

        # Random images with values between 0 and 1
        fluoI0fractions = list(
            _iter_amesh_data(amesh, coordinates, rstate, nmaps=len(element_groups))
        )
        scatterI0fractions = list(_iter_amesh_data(amesh, coordinates, rstate, nmaps=2))

        # Peak area counts within expo_time seconds
        linegroups = [
            xrf_spectra.EmissionLineGroup(
                element, group, (I0 * I0fraction).astype(numpy.uint32)
            )
            for I0fraction, (element, group) in zip(fluoI0fractions, element_groups)
        ]
        scattergroups = [
            xrf_spectra.ScatterLineGroup(
                "Compton000", (I0 * scatterI0fractions[0]).astype(numpy.uint32)
            ),
            xrf_spectra.ScatterLineGroup(
                "Peak000", (I0 * scatterI0fractions[1]).astype(numpy.uint32)
            ),
        ]

        # Theoretical XRF spectra
        theoretical_spectra, config = xrf_spectra.xrf_spectra(
            linegroups,
            scattergroups,
            energy=energy,
            flux=flux,
            elapsed_time=expo_time,
        )

        # Measured XRF spectra
        measured_data = apply_dualchannel_signal_processing(
            theoretical_spectra,
            elapsed_time=expo_time,
            counting_noise=counting_noise,
            integral_type=measured_dtype,
        )

        # ROI data: theoretical (I0 normalized), measured (added to the
        # measured signals) and I0 + live-time corrected.
        roi_data_theory = dict()
        roi_data_cor = dict()
        for i, roi in enumerate(rois, 1):
            roi_name = f"roi{i}"
            # Select the ROI channels on the last (channel) axis
            idx = Ellipsis, slice(*roi)
            roi_data_theory[roi_name] = (
                theoretical_spectra[idx].sum(axis=-1) / I0 * I0_reference
            )

            roi_meas = measured_data["spectrum"][idx].sum(axis=-1)
            measured_data[roi_name] = roi_meas
            cor = I0_reference / I0 * expo_time / measured_data["live_time"]
            roi_data_cor[roi_name] = cor * roi_meas

        _save_2d_xrf_scans(
            filename,
            scan_number,
            amesh,
            coordinates,
            energy,
            I0,
            I0_reference,
            nmcas,
            measured_data,
            roi_data_theory,
            roi_data_cor,
            linegroups,
            scattergroups,
            config,
        )
    return scan_numbers
class _AmeshAxis(NamedTuple):
    """One axis of an ``amesh`` scan: ``intervals + 1`` equally spaced positions."""

    name: str  # motor name
    start: float  # first position (scene pixel coordinates)
    end: float  # last position (scene pixel coordinates)
    intervals: int  # number of steps between start and end

    @property
    def size(self) -> int:
        """Number of scan points along this axis."""
        return self.intervals + 1

    @property
    def coordinates(self) -> numpy.ndarray:
        """Equally spaced positions from `start` to `end` (both included)."""
        return numpy.linspace(self.start, self.end, self.size)


class _Amesh(NamedTuple):
    """Description of a 2D ``amesh`` scan (fast axis first)."""

    fast: _AmeshAxis
    slow: _AmeshAxis
    expo_time: float  # exposure time per point in seconds

    def __str__(self):
        return (
            f"amesh {self.fast.name} {self.fast.start} {self.fast.end} {self.fast.intervals} "
            f"{self.slow.name} {self.slow.start} {self.slow.end} {self.slow.intervals} "
            f"{self.expo_time}"
        )

    @property
    def shape(self) -> Tuple[int, int]:
        """Scan shape as ``(nfast, nslow)``."""
        return (self.fast.size, self.slow.size)

    @property
    def size(self) -> int:
        """Total number of scan points."""
        return self.fast.size * self.slow.size


@functools.lru_cache(maxsize=1)
def _scene_data() -> Tuple[numpy.ndarray, List[str]]:
    """Data in pixel coordinates on which to interpolate to create amesh data.

    The result is cached: every caller receives the same array and list
    objects, so they must not be modified.

    :returns: array with shape `(nchannels, nfast, nslow)`, fast and slow axis name (in that order)
    """
    filename = os.path.join(os.path.dirname(__file__), "ihc.png")
    # (rows, columns, channels) -> (channels, rows, columns)
    scene_data = numpy.transpose(imread(filename), [2, 0, 1])
    _mark_fast_axis(scene_data)
    scene_axes = ["sampz", "sampy"]
    return scene_data, scene_axes


def _mark_fast_axis(channels: numpy.ndarray) -> None:
    """Modify the image intensities in-place to mark the fast axis (first image dimension).

    Draws, in every channel, a saturated square near the origin followed by a
    thinner bar running along the first image dimension up to its middle.
    """
    image_shape = channels[0].shape

    # Square centered at 10% of each dimension
    dstart = image_shape[0] // 10, image_shape[1] // 10
    dtick = dstart[0] // 4, dstart[1] // 4
    p0 = dstart[0] - dtick[0], dstart[1] - dtick[1]
    p1 = dstart[0] + dtick[0], dstart[1] + dtick[1]
    channels[:, p0[0] : p1[0], p0[1] : p1[1]] = 255

    # Thinner bar from the square to the middle of the first dimension
    dtick = dtick[0] // 2, dtick[1] // 2
    dend = image_shape[0] // 2, image_shape[1] // 10
    p0 = dstart[0] - dtick[0], dstart[1] - dtick[1]
    p1 = dend[0] + dtick[0], dend[1] + dtick[1]
    channels[:, p0[0] : p1[0], p0[1] : p1[1]] = 255


def _amesh_specs(
    shape: Tuple[int, int],
    mosaic: Tuple[int, int],
    max_deviation: float = 0,
    expo_time: float = 0.1,
) -> List[_Amesh]:
    """
    Generate a list of _Amesh instances based on the provided parameters.

    :param shape: Number of interpolation points along each axis (fast, slow).
    :param mosaic: Number of blocks along each axis (fast, slow).
    :param max_deviation: Maximum deviation from a perfect grid as a fraction of the scan step size.
    :param expo_time: Exposure time per point in seconds.
    :return: List of _Amesh instances (the fast-axis block index varies fastest).
    """
    scene_data, scene_axes = _scene_data()
    scene_shape = scene_data.shape[1:]
    fast_axes = _amesh_axes(
        scene_axes[0], scene_shape[0], shape[0], mosaic[0], max_deviation
    )
    slow_axes = _amesh_axes(
        scene_axes[1], scene_shape[1], shape[1], mosaic[1], max_deviation
    )
    return [_Amesh(fast, slow, expo_time) for slow in slow_axes for fast in fast_axes]


def _amesh_axes(
    name: str, scene_size: int, total_size: int, mosaic: int, max_deviation: float
) -> List[_AmeshAxis]:
    """
    Generate a list of _AmeshAxis instances based on the provided parameters.

    The `mosaic` blocks are placed back to back (the end of one block is the
    start of the next) and a border of `max_deviation` scan steps is kept on
    both sides, so that randomly displaced positions stay inside the scene.

    :param name: Axis name.
    :param scene_size: Total of scene data points.
    :param total_size: Number of scan points (split evenly over the blocks, any remainder is dropped).
    :param mosaic: Number of scans to span the scene.
    :param max_deviation: Maximum deviation from a perfect grid as a fraction of the scan step size.
    :raises ValueError: If `mosaic` is smaller than 1 or a block would have fewer than 2 points.
    """
    if mosaic < 1:
        raise ValueError(f"mosaic must be at least 1 along axis {name!r}, got {mosaic}")
    block_size = total_size // mosaic
    block_intervals = block_size - 1
    if block_intervals < 1:
        # Guards against a ZeroDivisionError (1 point per block) and
        # empty axes (fewer scan points than blocks).
        raise ValueError(
            f"each mosaic block along axis {name!r} needs at least 2 points "
            f"(scan size {total_size}, mosaic {mosaic})"
        )

    # Solve  mosaic * range + 2 * border = scene_size - 1
    # with   border = max_deviation * range / block_intervals
    scene_block_range = (scene_size - 1) / (
        mosaic + (2 * max_deviation / block_intervals)
    )
    interpolate_step_size = scene_block_range / block_intervals
    scan_border = max_deviation * interpolate_step_size
    scene_start = scan_border

    axes = list()
    for i in range(mosaic):
        block_start = scene_start + i * scene_block_range
        block_end = block_start + scene_block_range
        axes.append(
            _AmeshAxis(
                name=name,
                start=block_start,
                end=block_end,
                intervals=block_intervals,
            )
        )
    return axes


def _replace_dataset(group: h5py.Group, name: str, value) -> h5py.Dataset:
    """Write `value` as dataset `name` in `group`, deleting any existing item first."""
    if name in group:
        del group[name]
    group[name] = value
    return group[name]


def _save_2d_xrf_scans(
    filename: str,
    scan_number: int,
    amesh: _Amesh,
    coordinates: List[numpy.ndarray],
    energy: float,
    I0: numpy.ndarray,
    I0_reference: float,
    nmcas: int,
    measured_data: Dict[str, numpy.ndarray],
    roi_data_theory: Dict[str, numpy.ndarray],
    roi_data_cor: Dict[str, numpy.ndarray],
    linegroups: List[xrf_spectra.EmissionLineGroup],
    scattergroups: List[xrf_spectra.ScatterLineGroup],
    config: ConfigDict,
) -> None:
    """Write one simulated scan to the NXentry ``"{scan_number}.1"`` of `filename`.

    The file is opened in append mode: existing datasets of the entry are
    replaced, existing soft links are kept. All scan signals are stored
    flattened in Fortran order (fast axis varies first).

    :param filename: Output file path.
    :param scan_number: Scan number.
    :param amesh: Scan description.
    :param coordinates: Motor positions (fast axis first), each with shape `amesh.shape`.
    :param energy: Incident X-ray energy in keV.
    :param I0: Monitor counts with shape `amesh.shape`.
    :param I0_reference: Monitor value used for normalization.
    :param nmcas: Number of MCA detectors to save (all receive `measured_data`).
    :param measured_data: Measured MCA signals; ``"spectrum"`` has the channels on the last axis.
    :param roi_data_theory: Theoretical ROI maps.
    :param roi_data_cor: Corrected measured ROI maps.
    :param linegroups: Simulated emission line groups.
    :param scattergroups: Simulated scatter line groups.
    :param config: PyMca configuration used for the simulation.
    """
    fast_name = amesh.fast.name
    slow_name = amesh.slow.name
    fast_coordinates = coordinates[0].flatten(order="F")
    slow_coordinates = coordinates[1].flatten(order="F")

    with h5py_utils.File(filename, mode="a") as nxroot:
        nxroot.attrs["NX_class"] = "NXroot"
        nxroot.attrs["creator"] = "ewoksfluo"

        nxentry = nxroot.require_group(f"{scan_number}.1")
        nxentry.attrs["NX_class"] = "NXentry"
        _replace_dataset(nxentry, "title", str(amesh))

        nxinstrument = nxentry.require_group("instrument")
        nxinstrument.attrs["NX_class"] = "NXinstrument"

        measurement = nxentry.require_group("measurement")
        measurement.attrs["NX_class"] = "NXcollection"

        # Positioners: energy in all groups, motor start/end values as scalars
        for group_name in ("positioners_start", "positioners_end", "positioners"):
            group = nxinstrument.require_group(group_name)
            group.attrs["NX_class"] = "NXcollection"
            _replace_dataset(group, "energy", energy).attrs["units"] = "keV"
            if group_name == "positioners":
                # Motor values are soft-linked further down
                continue
            idx = 0 if group_name == "positioners_start" else -1
            _replace_dataset(group, fast_name, fast_coordinates[idx]).attrs[
                "units"
            ] = "um"
            _replace_dataset(group, slow_name, slow_coordinates[idx]).attrs[
                "units"
            ] = "um"
        positioners = nxinstrument["positioners"]

        # I0 data
        nxdetector = nxinstrument.require_group("I0")
        nxdetector.attrs["NX_class"] = "NXdetector"
        i0_data = _replace_dataset(nxdetector, "data", I0.flatten(order="F"))
        if "I0" not in measurement:
            measurement["I0"] = h5py.SoftLink(i0_data.name)

        # Fast and slow axis
        for motor_name, motor_values in (
            (fast_name, fast_coordinates),
            (slow_name, slow_coordinates),
        ):
            nxpositioner = nxinstrument.require_group(motor_name)
            nxpositioner.attrs["NX_class"] = "NXpositioner"
            value = _replace_dataset(nxpositioner, "value", motor_values)
            value.attrs["units"] = "um"
            if motor_name not in measurement:
                measurement[motor_name] = h5py.SoftLink(value.name)
            if motor_name not in positioners:
                positioners[motor_name] = h5py.SoftLink(value.name)

        # MCA detectors
        for i in range(nmcas):
            det_name = f"mca{i}"
            nxdetector = nxinstrument.require_group(det_name)
            nxdetector.attrs["NX_class"] = "NXdetector"
            for signal_name, signal_values in measured_data.items():
                if signal_name == "spectrum":
                    # (nfast, nslow, nchannels) -> (npoints, nchannels)
                    mca_shape = (amesh.size, signal_values.shape[-1])
                    dset = _replace_dataset(
                        nxdetector,
                        signal_name,
                        signal_values.reshape(mca_shape, order="F"),
                    )
                    if "data" not in nxdetector:
                        # Relative link within the detector group
                        nxdetector["data"] = h5py.SoftLink("spectrum")
                    meas_name = det_name
                else:
                    dset = _replace_dataset(
                        nxdetector, signal_name, signal_values.flatten(order="F")
                    )
                    meas_name = f"{det_name}_{signal_name}"
                if meas_name not in measurement:
                    measurement[meas_name] = h5py.SoftLink(dset.name)

        # Ground truth of the simulation
        nxprocess = nxentry.require_group("theory")
        nxprocess.attrs["NX_class"] = "NXprocess"
        _replace_dataset(nxprocess, "I0_reference", I0_reference)

        nxnote = nxprocess.require_group("configuration")
        nxnote.attrs["NX_class"] = "NXnote"
        _replace_dataset(nxnote, "type", "application/pymca")
        _replace_dataset(nxnote, "data", config.tostring())

        nxnote = nxprocess.require_group("description")
        nxnote.attrs["NX_class"] = "NXnote"
        _replace_dataset(nxnote, "type", "text/plain")
        description = [
            "- parameters: peak areas without dead-time",
            "- parameters_norm: peak areas without dead-time and I0 normalized",
            "- rois: MCA ROI's without dead-time and I0 normalized (theoretical)",
            "- rois_norm: MCA ROI's without dead-time and I0 normalized (calculated)",
        ]
        _replace_dataset(nxnote, "data", "\n".join(description))

        signals = {f"{g.element}-{g.name}": g.counts for g in linegroups}
        signals.update({g.name: g.counts for g in scattergroups})
        _save_nxdata(amesh, nxprocess, "parameters", signals, positioners)

        norm_signals = {
            signal_name: counts / I0 * I0_reference
            for signal_name, counts in signals.items()
        }
        _save_nxdata(amesh, nxprocess, "parameters_norm", norm_signals, positioners)

        if roi_data_theory:
            _save_nxdata(amesh, nxprocess, "rois", roi_data_theory, positioners)
        if roi_data_cor:
            _save_nxdata(amesh, nxprocess, "rois_norm", roi_data_cor, positioners)

        _replace_dataset(nxentry, "end_time", nexus_utils.now())


def _save_nxdata(
    amesh: _Amesh,
    parent: h5py.Group,
    name: str,
    signals: Dict[str, numpy.ndarray],
    positioners: h5py.Group,
) -> None:
    """Saves a set of signals as NeXus data to a given parent group.

    The first signal is the NXdata ``signal``, the others are ``auxiliary_signals``.
    Signals are stored flattened in Fortran order, with soft links to the scan motors
    as axes.

    :param amesh: Scan description.
    :param parent: The parent `h5py.Group` where the data will be saved.
    :param name: The name of the NXdata group.
    :param signals: Dictionary of signal names and their corresponding data arrays.
    :param positioners: Group containing the positioners (`sampz` and `sampy`) for the scan.
    :raises ValueError: If `signals` is empty.
    """
    if not signals:
        raise ValueError(f"no signals to save in NXdata group {name!r}")

    nxdata = parent.require_group(name)
    nxdata.attrs["NX_class"] = "NXdata"

    names = list(signals)
    nxdata.attrs["signal"] = names[0]
    if len(names) > 1:
        nxdata.attrs["auxiliary_signals"] = names[1:]
    elif "auxiliary_signals" in nxdata.attrs:
        # Drop a stale attribute left by a previous save of this group
        del nxdata.attrs["auxiliary_signals"]

    for signal_name, signal_values in signals.items():
        _replace_dataset(nxdata, signal_name, signal_values.flatten(order="F"))

    nxdata.attrs["axes"] = [amesh.fast.name, amesh.slow.name]  # Order: fast to slow
    for axis_name in (amesh.fast.name, amesh.slow.name):
        if axis_name not in nxdata:
            nxdata[axis_name] = h5py.SoftLink(positioners[axis_name].name)


def _iter_amesh_data(
    amesh: _Amesh,
    coordinates: List[numpy.ndarray],
    rstate: numpy.random.RandomState,
    nmaps: int = 1,
) -> Iterator[numpy.ndarray]:
    """Yield random samples of an image scanned with `amesh`.

    Each sample is a random linear combination of the scene channels,
    interpolated at the scan coordinates.

    :param amesh: Scan description.
    :param coordinates: Scan coordinates in scene pixel units (fast axis first).
    :param rstate: Random state used to draw the channel weights.
    :param nmaps: Number of samples to yield.
    :returns: signal with shape `amesh.shape` and values between 0 and 1 (F-order matrix).
    :raises ValueError: On first iteration, when the scan axes are not the scene axes.
    """
    scene_data, scene_axes = _scene_data()
    scan_axes = [amesh.fast.name, amesh.slow.name]
    if scan_axes == scene_axes:
        pass
    elif scan_axes == scene_axes[::-1]:
        # Flip fast and slow axis
        scene_data = numpy.transpose(scene_data, [0, 2, 1])
    else:
        raise ValueError("Must be an amesh scan over 'sampy' and 'sampz'")

    flat_coordinates = [x.flatten(order="F") for x in coordinates]

    for _ in range(nmaps):
        # Random linear combination of the RGB channels which
        # results in an image with values between 0 and 1
        fractions = rstate.uniform(low=0, high=1, size=3)
        fractions /= 255 * fractions.sum()
        image = sum(fractions[:, numpy.newaxis, numpy.newaxis] * scene_data)

        # Interpolate the image (pixel coordinates) on the coordinate grid.
        # Out-of-scene points take the nearest edge value ("nearest" ignores cval).
        iimage = ndimage.map_coordinates(
            image, flat_coordinates, order=1, cval=0, mode="nearest"
        )
        yield iimage.reshape(amesh.shape, order="F")


def _amesh_positions(
    asmesh: _Amesh,
    rstate: numpy.random.RandomState,
    max_deviation: float = 0,
) -> List[numpy.ndarray]:
    """Generates motor positions for an amesh scan with optional deviations (fast axis first).

    :param asmesh: Scan description.
    :param rstate: Random state used to draw the deviations.
    :param max_deviation: Maximum deviation from a perfect grid as a fraction of the
                          scan step size (uniformly distributed in ``[-d, d]``).
    :returns: list of position arrays with shape `asmesh.shape`.
    """
    # numpy.meshgrid returns a tuple on NumPy 2: always return a list as annotated
    positions = list(
        numpy.meshgrid(
            asmesh.fast.coordinates, asmesh.slow.coordinates, indexing="ij"
        )
    )
    if not max_deviation:
        return positions

    deviations = [
        abs(
            max_deviation
            if axis.size <= 1
            else (axis.end - axis.start) / (axis.size - 1) * max_deviation
        )
        for axis in [asmesh.fast, asmesh.slow]
    ]

    return [
        values + rstate.uniform(low=-d, high=d, size=asmesh.shape)
        for values, d in zip(positions, deviations)
    ]