Source code for ewoksfluo.xrffit.batch

import logging
import time
from concurrent.futures import Future
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import as_completed
from contextlib import ExitStack
from contextlib import contextmanager
from multiprocessing import Manager
from queue import Queue
from typing import Any
from typing import Dict
from typing import Generator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union

import numpy
from ewoksdata.data.url import as_dataurl
from silx.io import h5py_utils

from ... import resource_utils
from ..fit import fit_xrf_spectra
from ..fit.types import XRFBatchFitResult
from ..pymca_config import PyMcaXrfConfiguration
from ..pymca_config import modify_pymca_configuration
from ..pymca_config import pymca_configdict_from_file
from ..pymca_config import pymca_configdict_to_model
from ..pymca_config import set_beam_energy
from ._progress import ProgressLogger
from ._save import XrfResultWriter
from ._types import FitArguments
from ._types import FitTask
from ._types import FitUris
from ._types import ProgressStats

logger = logging.getLogger(__name__)


def fit_xrf_urls(
    fit_io_uris: List[FitUris],
    configuration: Union[PyMcaXrfConfiguration, str],
    individual_weights: Optional[bool] = None,
    positive_peak_areas: Optional[bool] = None,
    mlines: Optional[dict] = None,
    quantification: Union[Dict[str, float], bool, None] = None,
    energy_multiplier: Optional[float] = None,
    fast_fitting: Optional[bool] = None,
    native_fitting: Optional[bool] = None,
    native_legacy_fitting: Optional[bool] = None,
    diagnostics: Optional[bool] = None,
    max_workers: Optional[int] = None,
    block_size: Optional[int] = None,
) -> None:
    """Fit XRF spectra block by block and save the results.

    :param fit_io_uris: List of URI's to XRF spectra with shape `(num_spectra, num_channels)`
                        to fit and the associated URI to save the fit results.
    :param configuration: PyMca configuration, either as a model or as a file name.
    :param individual_weights: When fitting with weights, use the weight of each spectrum (slow)
                               instead of the average weight.
    :param positive_peak_areas: Forwarded unchanged to `fit_xrf_spectra`.
    :param mlines: Elements (keys) which M line group must be replaced by some
                   M subgroups (values). Defaults to None.
    :param quantification: Calculate mass fractions from peak areas.
    :param energy_multiplier: Adds a higher energy bound equal to `energy * energy_multiplier`
                              to include high-energy peaks. Default: no bound is added.
    :param fast_fitting: Fast fitting means fit all spectra by solving a single linear
                         system of equations.
    :param native_fitting: Use native PyMca batch processing.
    :param native_legacy_fitting: Use legacy native PyMca batch processing.
    :param diagnostics: Save the fit model and residuals.
    :param max_workers: Number of worker processes. `0` fits sequentially in the current
                        process. Defaults to the number of available CPUs, excluding the
                        current one.
    :param block_size: Number of spectra to read and fit at once. When not provided it is
                       derived from `resource_utils.array_block_size` and capped at 1000
                       (fast fitting) or 100 (otherwise).
    """
    block_size_max = 1000 if fast_fitting else 100

    if max_workers is None:
        max_workers = resource_utils.get_available_cpus(exclude_current=True)

    if block_size is None:
        # Use the first available spectra URI to determine the number of MCA
        # channels. Indexing `fit_io_uris[0].xrf_spectra_uris[0]` directly fails
        # with an `IndexError` when there is nothing to fit.
        first_spectra_uri = next(
            (
                spectra_uri
                for fit_io_uri in fit_io_uris
                for spectra_uri in fit_io_uri.xrf_spectra_uris
            ),
            None,
        )
        if first_spectra_uri is None:
            block_size = block_size_max
        else:
            # diagnostics = True, fast_fitting = True,  native_fitting = False: x9
            # diagnostics = True, fast_fitting = True,  native_fitting = True:  x4.5
            # diagnostics = True, fast_fitting = False, native_fitting = False: x8.5
            # diagnostics = True, fast_fitting = False, native_fitting = True:  x6.5
            # The largest multiplier of the list above is used for all cases.
            shape = (_get_mca_num_channels(first_spectra_uri),)
            block_size = min(
                resource_utils.array_block_size(
                    shape, numpy.float64, max_workers, multiplier=9
                ),
                block_size_max,
            )
            # A block size of zero would make the slice iteration fail:
            # always handle at least one spectrum per block.
            block_size = max(block_size, 1)

    logger.info(
        "Fit XRF spectra in blocks of %d in %d workers", block_size, max_workers
    )

    configuration = _initialize_configuration(
        configuration,
        mlines=mlines,
        quantification=quantification,
        fast_fitting=fast_fitting,
    )

    fit_arguments = FitArguments(
        configuration=configuration,
        quantification=bool(quantification),
        individual_weights=individual_weights,
        positive_peak_areas=positive_peak_areas,
        energy_multiplier=energy_multiplier,
        fast_fitting=fast_fitting,
        native_fitting=native_fitting,
        native_legacy_fitting=native_legacy_fitting,
        diagnostics=diagnostics,
    )

    with ExitStack() as stack:
        progress = stack.enter_context(ProgressLogger())

        default_group = "fit" if diagnostics else "parameters"
        writers: Dict[str, XrfResultWriter] = {}
        for fit_io_uri in fit_io_uris:
            output_parent_uri = fit_io_uri.output_parent_uri
            if output_parent_uri in writers:
                # Several inputs can share one output: open one writer per output.
                continue
            writers[output_parent_uri] = stack.enter_context(
                XrfResultWriter(output_parent_uri, default_group=default_group)
            )

        for writer in writers.values():
            writer.save_configuration(configuration)

        # Data structures and flow for reading, fitting and writing data:
        #  - Generate `FitTask` instances from `FitUris` instances.
        #  - Generate `XRFBatchFitResult` instances from `FitTask` instances.
        #  - Save `XRFBatchFitResult` instances with the corresponding
        #    `XrfResultWriter` instance.

        if max_workers == 0:
            _sequential_execution(
                fit_io_uris,
                fit_arguments,
                writers,
                progress,
                block_size,
            )
        else:
            _parallel_execution(
                fit_io_uris,
                fit_arguments,
                writers,
                stack,
                progress,
                block_size,
                max_workers,
            )
def _sequential_execution(
    fit_io_uris: List[FitUris],
    fit_arguments: FitArguments,
    writers: Dict[str, XrfResultWriter],
    progress: ProgressLogger,
    block_size: int,
) -> None:
    """Read, fit and save every block one after the other in the current process."""
    for fit_io_uri in fit_io_uris:
        for fit_task in _iter_fit_tasks(fit_io_uri, block_size):
            progress.accumulate(expected_num_spectra=fit_task.num_spectra)
            fit_task, fit_result = _execute_fit_task(fit_task, fit_arguments)
            _finish_task(fit_task, fit_result, writers, progress)


def _parallel_execution(
    fit_io_uris: List[FitUris],
    fit_arguments: FitArguments,
    writers: Dict[str, XrfResultWriter],
    stack: ExitStack,
    progress: ProgressLogger,
    block_size: int,
    max_workers: int,
) -> None:
    """Generate fit tasks and fit them in process pools, saving results in this process.

    One pool enumerates the fit tasks (one job per `FitUris` instance) and pushes
    them on a managed queue. A second pool executes the fit tasks. Fit results are
    written by the calling process as they become available.

    :param stack: Exit stack that owns the pools and the manager.
    """
    producers = set()
    producer_pool = stack.enter_context(_pool(max_workers, producers))

    manager = stack.enter_context(Manager())
    fit_tasks = manager.Queue()

    for fit_io_uri in fit_io_uris:
        future = producer_pool.submit(
            _enqueue_fit_tasks, fit_tasks, fit_io_uri, block_size
        )
        producers.add(future)

    fit_futures = set()
    fit_pool = stack.enter_context(_pool(max_workers, fit_futures))

    # NOTE(review): this loop polls the queue and the futures without sleeping.
    while producers or not fit_tasks.empty():
        # Submit new fit tasks
        while not fit_tasks.empty():
            fit_task = fit_tasks.get()
            progress.accumulate(expected_num_spectra=fit_task.num_spectra)
            fit_futures.add(fit_pool.submit(_execute_fit_task, fit_task, fit_arguments))

        # Save finished fit results
        for fit_task, fit_result in _purge_futures(fit_futures, "Fitting"):
            _finish_task(fit_task, fit_result, writers, progress)

        # Clear finished producers (re-raises when reading failed)
        for _ in _purge_futures(producers, "Reading data"):
            pass

    # All tasks are submitted: wait for the fits that are still running.
    for fit_task, fit_result in _wait_futures(fit_futures, "Fitting"):
        _finish_task(fit_task, fit_result, writers, progress)


def _finish_task(
    fit_task: FitTask,
    fit_result: XRFBatchFitResult,
    writers: Dict[str, XrfResultWriter],
    progress: ProgressLogger,
) -> None:
    """Save the result of a fit task and report its timings to the progress logger."""
    t0 = time.perf_counter()
    writer = writers[fit_task.output_parent_uri]
    writer.save_fit_result(fit_result, fit_task.start_out, fit_task.stop_out)
    t1 = time.perf_counter()
    fit_task.progress.seconds_write += t1 - t0

    progress.accumulate(
        read=fit_task.progress.seconds_read,
        process=fit_task.progress.seconds_process,
        write=fit_task.progress.seconds_write,
        processed_num_spectra=fit_result.num_spectra,
    )


@contextmanager
def _pool(
    max_workers: int, futures: Set[Future]
) -> Generator[ProcessPoolExecutor, None, None]:
    """Process pool context that cancels `futures` when the body raises an exception.

    :param max_workers: Number of worker processes of the pool.
    :param futures: Collection of futures, filled by the caller while the pool is in use.
    """
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        try:
            yield pool
        except Exception:
            # Cancel what did not start yet so that leaving the pool context
            # does not wait for it.
            for future in futures:
                future.cancel()
            raise


def _initialize_configuration(
    configuration: Union[PyMcaXrfConfiguration, str],
    mlines: Optional[dict] = None,
    quantification: Union[Dict[str, float], bool, None] = None,
    fast_fitting: bool = False,
) -> PyMcaXrfConfiguration:
    """Return a private, modified copy of the PyMca configuration.

    :param configuration: Configuration model or the name of a configuration file.
    :returns: A configuration the caller is free to modify: loaded from file or a
              deep copy of the provided model.
    """
    if isinstance(configuration, str):
        configuration = pymca_configdict_to_model(
            pymca_configdict_from_file(configuration)
        )
    else:
        configuration = configuration.model_copy(deep=True)
    # The beam energy is set per block of spectra in `_execute_fit_task`.
    modify_pymca_configuration(
        configuration,
        energy=None,
        energy_multiplier=None,
        mlines=mlines,
        quantification=quantification,
        fast_fitting=fast_fitting,
    )
    return configuration


def _enqueue_fit_tasks(fit_tasks: Queue, fit_io_uri: FitUris, block_size: int) -> None:
    """Put all fit tasks of `fit_io_uri` on the queue."""
    for fit_task in _iter_fit_tasks(fit_io_uri, block_size):
        fit_tasks.put(fit_task)


def _iter_fit_tasks(
    fit_io_uri: FitUris, block_size: int
) -> Generator[FitTask, None, None]:
    """Yield one fit task per block of at most `block_size` spectra.

    The spectra URI's are concatenated in the output: the output slice of a task
    is its input slice shifted by the number of spectra in all previous URI's.
    """
    offset = 0
    t0 = time.perf_counter()
    for xrf_spectra_uri, beam_energy_uri in zip(
        fit_io_uri.xrf_spectra_uris, fit_io_uri.beam_energy_uris
    ):
        stop_in = None
        for start_in, stop_in in _iter_scan_slices(xrf_spectra_uri, block_size):
            t1 = time.perf_counter()
            fit_task = FitTask(
                xrf_spectra_uri=xrf_spectra_uri,
                beam_energy_uri=beam_energy_uri,
                start_in=start_in,
                stop_in=stop_in,
                start_out=start_in + offset,
                stop_out=stop_in + offset,
                output_parent_uri=fit_io_uri.output_parent_uri,
                progress=ProgressStats(),
            )
            fit_task.progress.num_spectra = fit_task.num_spectra
            fit_task.progress.seconds_read += t1 - t0
            yield fit_task
            # Restart the clock so that the time spent by the consumer of this
            # generator is not counted as read time of the next task.
            t0 = time.perf_counter()
        if stop_in is not None:
            # `stop_in` is relative to the current dataset: accumulate, otherwise
            # the third and later datasets overwrite the output of the second one.
            offset += stop_in


def _purge_futures(futures: Set[Future], task: str) -> Generator[Any, None, None]:
    """Yield the results of the futures that are done and remove them from `futures`.

    :param task: Description of what the futures do, used in the error log.
    :raises Exception: Re-raises the exception of a failed future after logging it.
    """
    done = [future for future in futures if future.done()]
    for future in done:
        try:
            yield future.result()
        except Exception as ex:
            logger.error("%s failed: %s", task, ex)
            raise
        futures.remove(future)


def _wait_futures(futures: Set[Future], task: str) -> Generator[Any, None, None]:
    """Yield the results of all futures as they complete and remove them from `futures`.

    :param task: Description of what the futures do, used in the error log.
    :raises Exception: Re-raises the exception of a failed future after logging it.
    """
    for future in as_completed(futures):
        try:
            yield future.result()
        except Exception as ex:
            logger.error("%s failed: %s", task, ex)
            raise
        futures.remove(future)


def _select_block_energy(energy: Optional[numpy.ndarray]) -> Optional[Any]:
    """Reduce the beam energy data of one block to a single value.

    :param energy: `None`, a scalar or an array of beam energies.
    :returns: `None` when there is no energy (`None` or empty array), otherwise one energy.
    :raises NotImplementedError: A multi-dimensional array holds more than one
                                 distinct energy.
    """
    if energy is None or energy.ndim == 0:
        return energy
    if energy.size == 0:
        # An empty dataset carries no energy: avoid an IndexError below.
        return None
    if energy.ndim == 1:
        # The energy of the first spectrum is used for the whole block.
        return energy[0]
    # `numpy.unique` returns the sorted unique *values*, not indices.
    unique_energies = numpy.unique(energy)
    if len(unique_energies) != 1:
        raise NotImplementedError(
            f"Block has more than one primary beam energy: {unique_energies}"
        )
    return unique_energies[0]


def _execute_fit_task(
    fit_task: FitTask, fit_arguments: FitArguments
) -> Tuple[FitTask, XRFBatchFitResult]:
    """Read the spectra of a fit task and fit them.

    :returns: The fit task, with updated read and process timings, and the fit result.
    :raises NotImplementedError: The block has more than one primary beam energy.
    """
    t0 = time.perf_counter()
    spectra, energy = _read_data(
        fit_task.xrf_spectra_uri,
        fit_task.beam_energy_uri,
        fit_task.start_in,
        fit_task.stop_in,
    )
    t1 = time.perf_counter()

    energy = _select_block_energy(energy)
    fit_task.progress.seconds_read += t1 - t0

    if energy is None:
        # No beam energy for this block: fit with the configuration as provided.
        configuration = fit_arguments.configuration
    else:
        configuration = _modify_configuration(
            fit_arguments.configuration, energy, fit_arguments.energy_multiplier
        )

    fit_result = fit_xrf_spectra(
        spectra,
        configuration=configuration,
        quantification=fit_arguments.quantification,
        individual_weights=fit_arguments.individual_weights,
        positive_peak_areas=fit_arguments.positive_peak_areas,
        diagnostics=fit_arguments.diagnostics,
        fast_fitting=fit_arguments.fast_fitting,
        native_fitting=fit_arguments.native_fitting,
        native_legacy_fitting=fit_arguments.native_legacy_fitting,
    )
    t2 = time.perf_counter()
    fit_task.progress.seconds_process += t2 - t1
    return fit_task, fit_result


def _modify_configuration(
    configuration: PyMcaXrfConfiguration,
    energy: float,
    energy_multiplier: Optional[float],
) -> PyMcaXrfConfiguration:
    """Return a deep copy of the configuration with the beam energy set."""
    configuration = configuration.model_copy(deep=True)
    set_beam_energy(configuration, energy, energy_multiplier)
    return configuration


def _iter_scan_slices(
    spectra_uri: str, block_size: int
) -> Generator[Tuple[int, int], None, None]:
    """Yield `(start, stop)` index pairs covering the dataset in blocks of `block_size`."""
    # TODO: assumes static HDF5 file for now
    url = as_dataurl(spectra_uri)
    with h5py_utils.File(url.file_path()) as h5f:
        nspectra = len(h5f[url.data_path()])
    for start in range(0, nspectra, block_size):
        yield start, min(start + block_size, nspectra)


def _get_mca_num_channels(xrf_spectra_uri: str) -> int:
    """Return the size of the last dimension of the spectra dataset."""
    spectra_url = as_dataurl(xrf_spectra_uri)
    with h5py_utils.File(spectra_url.file_path()) as h5f_spectra:
        dset_spectra = h5f_spectra[spectra_url.data_path()]
        return dset_spectra.shape[-1]


def _read_data(
    xrf_spectra_uri: str, beam_energy_uri: Optional[str], start: int, stop: int
) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
    """Read the spectra `start:stop` and the associated beam energy.

    :returns: The spectra with shape `(Nspectra, Nchan)` and the beam energy. The
              energy is `None` without `beam_energy_uri`, the full dataset when it
              has at most one element, otherwise the slice `start:stop`.
    """
    spectra_url = as_dataurl(xrf_spectra_uri)

    if not beam_energy_uri:
        with h5py_utils.File(spectra_url.file_path()) as h5f_spectra:
            dset_spectra = h5f_spectra[spectra_url.data_path()]
            spectra = dset_spectra[start:stop]
        return spectra, None

    energy_url = as_dataurl(beam_energy_uri)
    with h5py_utils.File(spectra_url.file_path()) as h5f_spectra:
        with h5py_utils.File(energy_url.file_path()) as h5f_energy:
            dset_spectra = h5f_spectra[spectra_url.data_path()]
            spectra = dset_spectra[start:stop]
            dset_energy = h5f_energy[energy_url.data_path()]
            if dset_energy.size in (0, 1):
                # One energy (or none) for all spectra.
                energy = dset_energy[()]
            else:
                energy = dset_energy[start:stop]
    return spectra, energy