Source code for ewoksfluo.tasks.example_data.scan_data

import functools
import os
from typing import Dict
from typing import Iterator
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple

import h5py
import numpy
from silx.io import h5py_utils

try:
    from imageio.v3 import imread
except ImportError:
    try:
        from imageio.v2 import imread
    except ImportError:
        from imageio import imread

from scipy import ndimage

from ...io import hdf5
from ...xrffit.pymca_config import PyMcaXrfConfiguration
from ...xrffit.pymca_config import pymca_configdict_from_model
from .. import nexus_utils
from . import xrf_spectra
from .deadtime import apply_dualchannel_signal_processing
from .monitor import monitor_signal


[docs] def save_2d_xrf_scans( filename: str, emission_line_groups: List[str], first_scan_number: int, shape: Tuple[int, int], mosaic: Tuple[int, int], energy: float = 12, flux: float = 1e7, expo_time: float = 0.1, counting_noise: bool = True, integral_type: bool = True, rois: Sequence = tuple(), nmcas: int = 1, max_deviation: float = 0, seed: Optional[int] = None, ) -> List[int]: """Simulates and saves 2D XRF scans to a NeXus file. :param filename: Output file path. :param emission_line_groups: List of emission line groups as "Element-Group". :param first_scan_number: Scan number of the first scan in the output file. :param shape: Scan shape. :param mosaic: Split shape in this number of blocks. :param energy: Incident X-ray energy in keV. :param flux: Incident X-ray flux in photons per second. :param expo_time: Exposure time per point in seconds. :param counting_noise: If True, adds Poisson noise to data. :param integral_type: If True, data represents integrated intensities. :param rois: Regions of interest as energy ranges. :param ndetectors: Number of detectors. :param max_deviation: Maximum deviation from a perfect grid as a fraction of the scan step size. :param seed: Random seed for reproducibility. :returns: Scan numbers as many as there are blocks in the mosaic. :raises ValueError: If `shape` is not 2D. """ max_deviation = max(max_deviation, 0) rstate = numpy.random.RandomState(seed=seed) I0_max = int(flux * expo_time) emission_line_groups = [s.split("-") for s in emission_line_groups] specs = _amesh_specs( shape=shape, mosaic=mosaic, max_deviation=max_deviation, expo_time=expo_time ) scan_numbers = list() for scan_number, amesh in enumerate(specs, start=first_scan_number): scan_numbers.append(scan_number) coordinates = _amesh_positions(amesh, rstate, max_deviation) # decaying I0 image I0 = (I0_max * monitor_signal(expo_time, amesh.size)).reshape( amesh.shape, order="F" ) # random images with values between 0 and 1 fluoI0fractions = list( _iter_amesh_data( amesh, coordinates, rstate, nmaps=len(emission_line_groups) ) ) scatterI0fractions = list(_iter_amesh_data(amesh, coordinates, rstate, nmaps=2)) # Peak area counts within expo_time seconds linegroups = [ xrf_spectra.EmissionLineGroup( element, group, (I0 * I0fraction).astype(numpy.uint32) ) for I0fraction, (element, group) in zip( fluoI0fractions, emission_line_groups ) ] scattergroups = [ xrf_spectra.ScatterLineGroup( "Compton000", (I0 * scatterI0fractions[0]).astype(numpy.uint32) ), xrf_spectra.ScatterLineGroup( "Peak000", (I0 * scatterI0fractions[1]).astype(numpy.uint32) ), ] # Theoretical XRF spectra theoretical_spectra, config = xrf_spectra.xrf_spectra( linegroups, scattergroups, energy=energy, flux=flux, elapsed_time=expo_time, ) # Measured XRF spectra if integral_type: integral_type = numpy.uint32 else: integral_type = None measured_data = apply_dualchannel_signal_processing( theoretical_spectra, elapsed_time=expo_time, counting_noise=counting_noise, integral_type=integral_type, ) # ROI data (theoretical, measured and corrected) roi_data_theory = dict() roi_data_cor = dict() # I0 and LT corrected I0_reference = I0_max for i, roi in enumerate(rois, 1): roi_name = f"roi{i}" idx = Ellipsis, slice(*roi) roi_theory = theoretical_spectra[idx].sum(axis=-1) / I0 * I0_max roi_data_theory[roi_name] = roi_theory roi_meas = measured_data["spectrum"][idx].sum(axis=-1) measured_data[roi_name] = roi_meas cor = I0_reference / I0 * expo_time / measured_data["live_time"] roi_data_cor[roi_name] = cor * roi_meas _save_2d_xrf_scans( filename, scan_number, amesh, coordinates, energy, I0, I0_reference, nmcas, measured_data, roi_data_theory, roi_data_cor, linegroups, scattergroups, config, ) return scan_numbers
class _AmeshAxis(NamedTuple): name: str start: float end: float intervals: int @property def size(self) -> int: return self.intervals + 1 @property def coordinates(self) -> numpy.ndarray: return numpy.linspace(self.start, self.end, self.size) class _Amesh(NamedTuple): fast: _AmeshAxis slow: _AmeshAxis expo_time: float def __str__(self): return f"amesh {self.fast.name} {self.fast.start} {self.fast.end} {self.fast.intervals} {self.slow.name} {self.slow.start} {self.slow.end} {self.slow.intervals} {self.expo_time}" @property def shape(self) -> Tuple[int]: return (self.fast.size, self.slow.size) @property def size(self) -> int: return self.fast.size * self.slow.size @functools.lru_cache(maxsize=1) def _scene_data() -> Tuple[numpy.ndarray, List[str]]: """Data in pixel coordinates on which to interpolate for create amesh data. :returns: array with shape `(nimages, nfast, nslow)`, fast and slow axis name (in that order) """ filename = os.path.join(os.path.dirname(__file__), "ihc.png") scene_data = numpy.transpose(imread(filename), [2, 0, 1]) _mark_fast_axis(scene_data) scene_axes = ["sampz", "sampy"] return scene_data, scene_axes def _mark_fast_axis(channels: numpy.ndarray) -> None: """Modify the image intensities to mark the fast axis (first dimension).""" image_shape = channels[0].shape dstart = image_shape[0] // 10, image_shape[1] // 10 dtick = dstart[0] // 4, dstart[1] // 4 p0 = dstart[0] - dtick[0], dstart[1] - dtick[1] p1 = dstart[0] + dtick[0], dstart[1] + dtick[1] channels[:, p0[0] : p1[0], p0[1] : p1[1]] = 255 dtick = dtick[0] // 2, dtick[1] // 2 dend = image_shape[0] // 2, image_shape[1] // 10 p0 = dstart[0] - dtick[0], dstart[1] - dtick[1] p1 = dend[0] + dtick[0], dend[1] + dtick[1] channels[:, p0[0] : p1[0], p0[1] : p1[1]] = 255 def _amesh_specs( shape: Tuple[int, int], mosaic: Tuple[int, int], max_deviation: float = 0, expo_time: float = 0.1, ) -> List[_Amesh]: """ Generate a list of _Amesh instances based on the provided parameters. :param shape: Number of interpolation points along each axis (fast, slow). :param mosaic: Number of blocks along each axis (fast, slow). :param max_deviation: Maximum deviation from a perfect grid as a fraction of the scan step size. :param expo_time: Exposure time per point in seconds. :return: List of _Amesh instances. """ scene_data, scene_axes = _scene_data() scene_shape = scene_data.shape[1:] fast_axes = _amesh_axes( scene_axes[0], scene_shape[0], shape[0], mosaic[0], max_deviation ) slow_axes = _amesh_axes( scene_axes[1], scene_shape[1], shape[1], mosaic[1], max_deviation ) amesh_list = list() for slow in slow_axes: for fast in fast_axes: amesh_list.append(_Amesh(fast, slow, expo_time)) return amesh_list def _amesh_axes( name: str, scene_size: int, total_size: int, mosaic: int, max_deviation: float ) -> List[_AmeshAxis]: """ Generate a list of _AmeshAxis instances based on the provided parameters. :param name: Axis name. :param scene_size: Total of scene data points. :param total_size: Number of scan points. :param mosaic: Number of scans to span the scene. :param max_deviation: Maximum deviation from a perfect grid as a fraction of the scan step size. """ block_size = total_size // mosaic block_intervals = block_size - 1 scene_block_range = (scene_size - 1) / ( mosaic + (2 * max_deviation / block_intervals) ) interpolate_step_size = scene_block_range / block_intervals scan_border = max_deviation * interpolate_step_size scene_start = scan_border # scene_end = scene_size - 1 - scan_border # assert ((scene_end - scene_start) / scene_block_range - mosaic) < 1e-6 axes = list() for i in range(mosaic): block_start = scene_start + i * scene_block_range block_end = block_start + scene_block_range axes.append( _AmeshAxis( name=name, start=block_start, end=block_end, intervals=block_intervals, ) ) return axes def _save_2d_xrf_scans( filename: str, scan_number: int, amesh: _Amesh, coordinates: List[numpy.ndarray], energy: float, I0: numpy.ndarray, I0_reference: float, nmcas: int, measured_data: Dict[str, numpy.ndarray], roi_data_theory: Dict[str, numpy.ndarray], roi_data_cor: Dict[str, numpy.ndarray], linegroups: List[xrf_spectra.EmissionLineGroup], scattergroups: List[xrf_spectra.ScatterLineGroup], pymca_configuration: PyMcaXrfConfiguration, ) -> None: with h5py_utils.File(filename, mode="a") as nxroot: scan_name = f"{scan_number}.1" nxroot.attrs["NX_class"] = "NXroot" nxroot.attrs["creator"] = "ewoksfluo" nxentry = nxroot.require_group(scan_name) nxentry.attrs["NX_class"] = "NXentry" if "title" in nxentry: del nxentry["title"] nxentry["title"] = str(amesh) nxinstrument = nxentry.require_group("instrument") nxinstrument.attrs["NX_class"] = "NXinstrument" measurement = nxentry.require_group("measurement") measurement.attrs["NX_class"] = "NXcollection" fast_coordinates = coordinates[0].flatten(order="F") slow_coordinates = coordinates[1].flatten(order="F") # Positioners for name in ("positioners_start", "positioners_end", "positioners"): group = nxinstrument.require_group(name) group.attrs["NX_class"] = "NXcollection" if "energy" in group: del group["energy"] group["energy"] = energy group["energy"].attrs["units"] = "keV" if name == "positioners": continue if name == "positioners_start": idx = 0 else: idx = -1 if amesh.fast.name in group: del group[amesh.fast.name] group[amesh.fast.name] = fast_coordinates[idx] group[amesh.fast.name].attrs["units"] = "um" if amesh.slow.name in group: del group[amesh.slow.name] group[amesh.slow.name] = slow_coordinates[idx] group[amesh.slow.name].attrs["units"] = "um" positioners = nxinstrument["positioners"] # I0 data nxdetector = nxinstrument.require_group("I0") nxdetector.attrs["NX_class"] = "NXdetector" if "data" in nxdetector: del nxdetector["data"] nxdetector["data"] = I0.flatten(order="F") if "I0" not in measurement: measurement["I0"] = h5py.SoftLink(nxdetector["data"].name) # Fast axis nxpositioner = nxinstrument.require_group(amesh.fast.name) nxpositioner.attrs["NX_class"] = "NXpositioner" if "value" in nxpositioner: del nxpositioner["value"] nxpositioner["value"] = fast_coordinates nxpositioner["value"].attrs["units"] = "um" if amesh.fast.name not in measurement: measurement[amesh.fast.name] = h5py.SoftLink(nxpositioner["value"].name) if amesh.fast.name not in positioners: positioners[amesh.fast.name] = h5py.SoftLink(nxpositioner["value"].name) # Slow axis nxpositioner = nxinstrument.require_group(amesh.slow.name) nxpositioner.attrs["NX_class"] = "NXpositioner" if "value" in nxpositioner: del nxpositioner["value"] nxpositioner["value"] = slow_coordinates nxpositioner["value"].attrs["units"] = "um" if amesh.slow.name not in measurement: measurement[amesh.slow.name] = h5py.SoftLink(nxpositioner["value"].name) if amesh.slow.name not in positioners: positioners[amesh.slow.name] = h5py.SoftLink(nxpositioner["value"].name) # MCA detector for i in range(nmcas): det_name = f"mca{i}" nxdetector = nxinstrument.require_group(det_name) nxdetector.attrs["NX_class"] = "NXdetector" for signal_name, signal_values in measured_data.items(): if signal_name in nxdetector: del nxdetector[signal_name] if signal_name == "spectrum": mca_shape = (amesh.size, signal_values.shape[-1]) nxdetector[signal_name] = signal_values.reshape( mca_shape, order="F" ) if "data" not in nxdetector: nxdetector["data"] = h5py.SoftLink("spectrum") meas_name = det_name else: nxdetector[signal_name] = signal_values.flatten(order="F") meas_name = f"{det_name}_{signal_name}" if meas_name not in measurement: measurement[meas_name] = h5py.SoftLink(nxdetector[signal_name].name) nxprocess = nxentry.require_group("theory") nxprocess.attrs["NX_class"] = "NXprocess" if "I0_reference" in nxprocess: del nxprocess["I0_reference"] nxprocess["I0_reference"] = I0_reference nxnote = nxprocess.require_group("configuration") nxnote.attrs["NX_class"] = "NXnote" if "data" in nxnote: del nxnote["data"] if "type" in nxnote: del nxnote["type"] nxnote["type"] = "application/pymca" nxnote["data"] = pymca_configdict_from_model(pymca_configuration).tostring() nxnote = nxprocess.require_group("description") nxnote.attrs["NX_class"] = "NXnote" if "data" in nxnote: del nxnote["data"] if "type" in nxnote: del nxnote["type"] nxnote["type"] = "text/plain" description = [ "- parameters: peak areas without dead-time", "- parameters_norm: peak areas without dead-time and I0 normalized", "- rois: MCA ROI's without dead-time and I0 normalized (theoretical)", "- rois_norm: MCA ROI's without dead-time and I0 normalized (calculated)", ] nxnote["data"] = "\n".join(description) signals = {f"{g.element}-{g.name}": g.counts for g in linegroups} signals.update({g.name: g.counts for g in scattergroups}) _save_nxdata(amesh, nxprocess, "parameters", signals, positioners) signals = { f"{g.element}-{g.name}": g.counts / I0 * I0_reference for g in linegroups } signals.update({g.name: g.counts / I0 * I0_reference for g in scattergroups}) _save_nxdata(amesh, nxprocess, "parameters_norm", signals, positioners) if roi_data_theory: _save_nxdata(amesh, nxprocess, "rois", roi_data_theory, positioners) if roi_data_cor: _save_nxdata(amesh, nxprocess, "rois_norm", roi_data_cor, positioners) if "end_time" in nxentry: del nxentry["end_time"] nxentry["end_time"] = nexus_utils.now() def _save_nxdata( amesh: _Amesh, parent: hdf5.GroupType, name: str, signals: Dict[str, numpy.ndarray], positioners: hdf5.GroupType, ) -> None: """Saves a set of signals as NeXus data to a given parent group. :param amesh: Scan description. :param parent: The parent `h5py.Group` where the data will be saved. :param name: The name of the dataset to be saved under. :param signals: Dictionary of signal names and their corresponding data arrays. :param positioners: Group containing the positioners (`sampz` and `sampy`) for the scan. """ nxdata = parent.require_group(name) nxdata.attrs["NX_class"] = "NXdata" names = list(signals.keys()) nxdata.attrs["signal"] = names[0] if len(names) > 1: nxdata.attrs["auxiliary_signals"] = names[1:] for signal_name, signal_values in signals.items(): if signal_name in nxdata: del nxdata[signal_name] nxdata[signal_name] = signal_values.flatten(order="F") nxdata.attrs["axes"] = [amesh.fast.name, amesh.slow.name] # Order: fast to slow if amesh.fast.name not in nxdata: nxdata[amesh.fast.name] = h5py.SoftLink(positioners[amesh.fast.name].name) if amesh.slow.name not in nxdata: nxdata[amesh.slow.name] = h5py.SoftLink(positioners[amesh.slow.name].name) def _iter_amesh_data( amesh: _Amesh, coordinates: List[numpy.ndarray], rstate: numpy.random.RandomState, nmaps: int = 1, ) -> Iterator[numpy.ndarray]: """Yield random samples of an image scanned with `amesh`. :param amesh: Scan description. :param coordinates: Scan coordinates (fast axis first). :returns: signal (F-order matrix). """ scene_data, scene_axes = _scene_data() scan_axes = [amesh.fast.name, amesh.slow.name] if scan_axes == scene_axes: pass elif scan_axes == scene_axes[::-1]: # Flip fast and slow axis scene_data = numpy.transpose(scene_data, [0, 2, 1]) else: raise ValueError("Must be an amesh scan over 'sampy' and 'sampz'") flat_coordinates = [x.flatten(order="F") for x in coordinates] for _ in range(nmaps): # Random linear combination of the RGB channels which # results in an image with values between 0 and 1 fractions = rstate.uniform(low=0, high=1, size=3) fractions /= 255 * fractions.sum() image = sum(fractions[:, numpy.newaxis, numpy.newaxis] * scene_data) # Interpolate the image (pixel coordinates) on the coordinate grid iimage = ndimage.map_coordinates( image, flat_coordinates, order=1, cval=0, mode="nearest" ) yield iimage.reshape(amesh.shape, order="F") def _amesh_positions( asmesh: _Amesh, rstate: numpy.random.RandomState, max_deviation: float = 0, ) -> List[numpy.ndarray]: """Generates motor positions for an amesh scan with optional deviations (fast axis first).""" positions = numpy.meshgrid( asmesh.fast.coordinates, asmesh.slow.coordinates, indexing="ij" ) if not max_deviation: return positions deviations = [ abs( max_deviation if axis.size <= 1 else (axis.end - axis.start) / (axis.size - 1) * max_deviation ) for axis in [asmesh.fast, asmesh.slow] ] positions = [ values + rstate.uniform(low=-d, high=d, size=asmesh.shape) for values, d in zip(positions, deviations) ] return positions