import logging
import time
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import as_completed
from concurrent.futures import wait
from contextlib import ExitStack
from contextlib import contextmanager
from multiprocessing import Manager
from queue import Queue
from typing import Any
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union

import numpy
from ewoksdata.data.url import as_dataurl
from silx.io import h5py_utils

from ... import resource_utils
from ..fit import fit_xrf_spectra
from ..fit.types import XRFBatchFitResult
from ..pymca_config import PyMcaXrfConfiguration
from ..pymca_config import modify_pymca_configuration
from ..pymca_config import pymca_configdict_from_file
from ..pymca_config import pymca_configdict_to_model
from ..pymca_config import set_beam_energy
from ._progress import ProgressLogger
from ._save import XrfResultWriter
from ._types import FitArguments
from ._types import FitTask
from ._types import FitUris
from ._types import ProgressStats
logger = logging.getLogger(__name__)


def fit_xrf_urls(
    fit_io_uris: List[FitUris],
    configuration: Union[PyMcaXrfConfiguration, str],
    individual_weights: Optional[bool] = None,
    positive_peak_areas: Optional[bool] = None,
    mlines: Optional[dict] = None,
    quantification: Union[Dict[str, float], bool, None] = None,
    energy_multiplier: Optional[float] = None,
    fast_fitting: Optional[bool] = None,
    native_fitting: Optional[bool] = None,
    native_legacy_fitting: Optional[bool] = None,
    diagnostics: Optional[bool] = None,
    max_workers: Optional[int] = None,
    block_size: Optional[int] = None,
) -> None:
    """
    Fit XRF spectra referenced by URI's and save the fit results.

    :param fit_io_uris: List of URI's to XRF spectra with shape `(num_spectra, num_channels)`
        to fit and the associated URI to save the fit results. Nothing is done when empty.
    :param configuration: PyMca configuration (model, or path of a configuration file).
        It is copied and never modified.
    :param individual_weights: When fitting with weights, use the weight of each
        spectrum (slow) instead of the average weight.
    :param positive_peak_areas: Passed to the fit (presumably constrains the peak
        areas to be non-negative).
    :param mlines: elements (keys) which M line group must be replaced by some M subgroups (values). Defaults to None.
    :param quantification: Calculate mass fractions from peak area's.
    :param energy_multiplier: adds a higher energy bound equal to energy*energy_multiplier
        to include high-energy peaks. Default: no bound is added. Only applied to blocks
        for which a primary beam energy is read.
    :param fast_fitting: Fast fitting means fit all spectra by solving a single linear system of equations.
    :param native_fitting: Use native PyMca batch processing.
    :param native_legacy_fitting: Use legacy native PyMca batch processing.
    :param diagnostics: Calculate and save the fit model and residuals.
    :param max_workers: Number of worker processes. `0` fits sequentially in the
        current process. Defaults to the number of available CPU's excluding the current one.
    :param block_size: Number of spectra to read and fit at once. Defaults to a value
        derived from the number of channels and workers, capped at 1000 with fast
        fitting and 100 otherwise.
    """
    if not fit_io_uris:
        logger.info("No XRF spectra to fit")
        return

    block_size_max = 1000 if fast_fitting else 100
    if max_workers is None:
        max_workers = resource_utils.get_available_cpus(exclude_current=True)
    if block_size is None:
        # Presumably memory multipliers relative to one float64 spectrum:
        # diagnostics = True, fast_fitting = True, native_fitting = False: x9
        # diagnostics = True, fast_fitting = True, native_fitting = True: x4.5
        # diagnostics = True, fast_fitting = False, native_fitting = False: x8.5
        # diagnostics = True, fast_fitting = False, native_fitting = True: 6.5
        # The worst case (x9) is used for all modes.
        shape = (_get_mca_num_channels(fit_io_uris[0].xrf_spectra_uris[0]),)
        block_size = min(
            resource_utils.array_block_size(
                shape, numpy.float64, max_workers, multiplier=9
            ),
            block_size_max,
        )
    logger.info(
        "Fit XRF spectra in blocks of %d in %d workers", block_size, max_workers
    )

    configuration = _initialize_configuration(
        configuration,
        mlines=mlines,
        quantification=quantification,
        fast_fitting=fast_fitting,
    )
    fit_arguments = FitArguments(
        configuration=configuration,
        quantification=bool(quantification),
        individual_weights=individual_weights,
        positive_peak_areas=positive_peak_areas,
        energy_multiplier=energy_multiplier,
        fast_fitting=fast_fitting,
        native_fitting=native_fitting,
        native_legacy_fitting=native_legacy_fitting,
        diagnostics=diagnostics,
    )

    with ExitStack() as stack:
        progress = stack.enter_context(ProgressLogger())
        default_group = "fit" if diagnostics else "parameters"

        # One writer per output URI: several `FitUris` may share an output, and a
        # dict comprehension would enter a second writer context for the same URI.
        writers: Dict[str, XrfResultWriter] = {}
        for fit_io_uri in fit_io_uris:
            output_uri = fit_io_uri.output_parent_uri
            if output_uri not in writers:
                writers[output_uri] = stack.enter_context(
                    XrfResultWriter(output_uri, default_group=default_group)
                )
        for writer in writers.values():
            writer.save_configuration(configuration)

        # Data structures and flow for reading, fitting and writing data:
        # - Generate `FitTask` instances from `FitUris` instances.
        # - Generate `XRFBatchFitResult` instances from `FitTask` instances.
        # - Save `XRFBatchFitResult` instances to the corresponding `XrfResultWriter` instance.
        if max_workers == 0:
            _sequential_execution(
                fit_io_uris,
                fit_arguments,
                writers,
                progress,
                block_size,
            )
        else:
            _parallel_execution(
                fit_io_uris,
                fit_arguments,
                writers,
                stack,
                progress,
                block_size,
                max_workers,
            )
def _sequential_execution(
    fit_io_uris: List[FitUris],
    fit_arguments: FitArguments,
    writers: Dict[str, XrfResultWriter],
    progress: ProgressLogger,
    block_size: int,
):
    """Read, fit and save every block one after the other in the current process.

    Blocks are generated lazily, one `FitUris` after the other, and each result is
    saved before the next block is read.
    """
    all_tasks = (
        task
        for uris in fit_io_uris
        for task in _iter_fit_tasks(uris, block_size)
    )
    for task in all_tasks:
        progress.accumulate(expected_num_spectra=task.num_spectra)
        finished_task, result = _execute_fit_task(task, fit_arguments)
        _finish_task(finished_task, result, writers, progress)
def _parallel_execution(
    fit_io_uris: List[FitUris],
    fit_arguments: FitArguments,
    writers: Dict[str, XrfResultWriter],
    stack: ExitStack,
    progress: ProgressLogger,
    block_size: int,
    max_workers: int,
):
    """Read and fit blocks in worker processes and save the results in this process.

    - A producer pool runs `_enqueue_fit_tasks` once per `FitUris`; it puts
      `FitTask` instances on a queue owned by a `multiprocessing.Manager`.
    - This process submits queued tasks to a fit pool running `_execute_fit_task`
      and saves every finished result with the writer of its output URI.

    The pools and the manager are entered on `stack`, so they are shut down when
    the caller's `ExitStack` closes.
    """
    # Maximum time to block while idle before polling the task queue again.
    poll_seconds = 0.1

    producers: Set[Future] = set()
    producer_pool = stack.enter_context(_pool(max_workers, producers))
    manager = stack.enter_context(Manager())
    fit_tasks = manager.Queue()
    for fit_io_uri in fit_io_uris:
        future = producer_pool.submit(
            _enqueue_fit_tasks, fit_tasks, fit_io_uri, block_size
        )
        producers.add(future)

    fit_futures: Set[Future] = set()
    fit_pool = stack.enter_context(_pool(max_workers, fit_futures))
    # A producer has put all of its tasks before its future is done, so once
    # `producers` is empty the queue holds every task still to be submitted.
    while producers or not fit_tasks.empty():
        # Submit new fit tasks
        submitted = False
        while not fit_tasks.empty():
            fit_task = fit_tasks.get()
            progress.accumulate(expected_num_spectra=fit_task.num_spectra)
            fit_futures.add(fit_pool.submit(_execute_fit_task, fit_task, fit_arguments))
            submitted = True
        # Save finished fit results
        for fit_task, fit_result in _purge_futures(fit_futures, "Fitting"):
            _finish_task(fit_task, fit_result, writers, progress)
        # Clear finished producers
        for _ in _purge_futures(producers, "Reading data"):
            pass
        if producers and not submitted:
            # Idle: block until a producer or a fit finishes instead of spinning on
            # `fit_tasks.empty()` (every call is a request to the manager process).
            # The timeout bounds the delay before newly queued tasks are picked up.
            wait(
                producers | fit_futures,
                timeout=poll_seconds,
                return_when=FIRST_COMPLETED,
            )

    for fit_task, fit_result in _wait_futures(fit_futures, "Fitting"):
        _finish_task(fit_task, fit_result, writers, progress)
def _finish_task(
    fit_task: FitTask,
    fit_result: XRFBatchFitResult,
    writers: Dict[str, XrfResultWriter],
    progress: ProgressLogger,
) -> None:
    """Save one fit result at its output rows and report the task statistics.

    The write time is added to `fit_task.progress` before the read/process/write
    times and the number of processed spectra are accumulated in `progress`.
    """
    stats = fit_task.progress
    started = time.perf_counter()
    writers[fit_task.output_parent_uri].save_fit_result(
        fit_result, fit_task.start_out, fit_task.stop_out
    )
    stats.seconds_write += time.perf_counter() - started
    progress.accumulate(
        read=stats.seconds_read,
        process=stats.seconds_process,
        write=stats.seconds_write,
        processed_num_spectra=fit_result.num_spectra,
    )
@contextmanager
def _pool(
    max_workers: int, futures: Iterable[Future]
) -> Generator[ProcessPoolExecutor, None, None]:
    """Process pool that cancels the pending `futures` when the managed block fails.

    `futures` is only read when the managed block raises, so callers may keep
    adding to it after entering. Cancellation also happens on `KeyboardInterrupt`
    and `SystemExit`; otherwise the pool shutdown on exit would wait for every
    queued job. Jobs that are already running cannot be cancelled and are still
    waited for.

    :param max_workers: Number of worker processes.
    :param futures: Futures submitted to this pool (e.g. a set filled by the caller).
    """
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        try:
            yield pool
        except BaseException:
            for future in list(futures):
                future.cancel()
            raise
def _initialize_configuration(
    configuration: Union[PyMcaXrfConfiguration, str],
    mlines: Optional[dict] = None,
    quantification: Union[Dict[str, float], bool, None] = None,
    fast_fitting: bool = False,
) -> PyMcaXrfConfiguration:
    """Return a private, modified copy of the PyMca configuration.

    A `str` is loaded as a PyMca configuration file; a model is deep-copied so
    the caller's instance is never modified. The beam energy and energy multiplier
    are left untouched here (`None`).
    """
    if not isinstance(configuration, str):
        model = configuration.model_copy(deep=True)
    else:
        model = pymca_configdict_to_model(pymca_configdict_from_file(configuration))
    modify_pymca_configuration(
        model,
        energy=None,
        energy_multiplier=None,
        mlines=mlines,
        quantification=quantification,
        fast_fitting=fast_fitting,
    )
    return model
def _enqueue_fit_tasks(fit_tasks: Queue, fit_io_uri: FitUris, block_size: int) -> None:
    """Producer job: put every `FitTask` of `fit_io_uri` on the shared `fit_tasks` queue.

    Returns only after all tasks have been put.
    """
    tasks = _iter_fit_tasks(fit_io_uri, block_size)
    for task in tasks:
        fit_tasks.put(task)
def _iter_fit_tasks(
    fit_io_uri: FitUris, block_size: int
) -> Generator[FitTask, None, None]:
    """Generate the fit tasks of all scans in `fit_io_uri`.

    Each scan (`xrf_spectra_uris[i]` with `beam_energy_uris[i]`) is split into
    blocks of at most `block_size` spectra. All scans are written to the same
    output one after the other: the output rows of a scan start right after the
    last row of the previous scan.

    `progress.seconds_read` of a task is the time spent preparing that task only;
    the time the consumer spends between two tasks is not counted.
    """
    offset = 0
    t0 = time.perf_counter()
    # NOTE(review): zip() silently ignores scans beyond the shorter of the two URI
    # lists; confirm `beam_energy_uris` always has one entry per scan.
    for xrf_spectra_uri, beam_energy_uri in zip(
        fit_io_uri.xrf_spectra_uris, fit_io_uri.beam_energy_uris
    ):
        stop_in = None
        for start_in, stop_in in _iter_scan_slices(xrf_spectra_uri, block_size):
            t1 = time.perf_counter()
            fit_task = FitTask(
                xrf_spectra_uri=xrf_spectra_uri,
                beam_energy_uri=beam_energy_uri,
                start_in=start_in,
                stop_in=stop_in,
                start_out=start_in + offset,
                stop_out=stop_in + offset,
                output_parent_uri=fit_io_uri.output_parent_uri,
                progress=ProgressStats(),
            )
            fit_task.progress.num_spectra = fit_task.num_spectra
            fit_task.progress.seconds_read += t1 - t0
            yield fit_task
            # Restart the clock: time spent by the consumer is not read time.
            t0 = time.perf_counter()
        if stop_in is not None:
            # `stop_in` of the last slice is the number of spectra of this scan;
            # accumulate so the next scan starts after all previous scans.
            offset += stop_in
def _purge_futures(futures: Set[Future], task: str) -> Generator[Any, None, None]:
    """Yield the results of the futures in `futures` that are already done, without blocking.

    The done futures are selected once, when iteration starts. A future is removed
    from `futures` after its result has been consumed. A failed future is logged
    with `task` as context and its exception is re-raised; it stays in `futures`.
    """
    for fut in [candidate for candidate in futures if candidate.done()]:
        try:
            yield fut.result()
        except Exception as error:
            logger.error("%s failed: %s", task, error)
            raise
        futures.remove(fut)
def _wait_futures(futures: Set[Future], task: str) -> Generator[Any, None, None]:
    """Yield the result of every future in `futures` in completion order.

    Blocks until each future is done. A future is removed from `futures` after
    its result has been consumed. A failed future is logged with `task` as
    context and its exception is re-raised; the remaining futures stay in `futures`.
    """
    completion_order = as_completed(futures)
    for fut in completion_order:
        try:
            yield fut.result()
        except Exception as error:
            logger.error("%s failed: %s", task, error)
            raise
        futures.remove(fut)
def _execute_fit_task(
    fit_task: FitTask, fit_arguments: FitArguments
) -> Tuple[FitTask, XRFBatchFitResult]:
    """Read one block of spectra, fit it and return the result.

    :param fit_task: Block to read; the read and process times in its `progress`
        are incremented in place.
    :param fit_arguments: Fit options shared by all blocks.
    :returns: `(fit_task, fit_result)`. The task is returned so that its progress
        updates reach the caller when this runs in a worker process.
    :raises NotImplementedError: When a multi-dimensional beam energy array holds
        more than one distinct energy.
    """
    t0 = time.perf_counter()
    spectra, energy = _read_data(
        fit_task.xrf_spectra_uri,
        fit_task.beam_energy_uri,
        fit_task.start_in,
        fit_task.stop_in,
    )
    t1 = time.perf_counter()
    fit_task.progress.seconds_read += t1 - t0

    energy = _block_beam_energy(energy)
    if energy is None:
        # No primary beam energy for this block: fit with the configuration as is.
        configuration = fit_arguments.configuration
    else:
        configuration = _modify_configuration(
            fit_arguments.configuration, energy, fit_arguments.energy_multiplier
        )

    fit_result = fit_xrf_spectra(
        spectra,
        configuration=configuration,
        quantification=fit_arguments.quantification,
        individual_weights=fit_arguments.individual_weights,
        positive_peak_areas=fit_arguments.positive_peak_areas,
        diagnostics=fit_arguments.diagnostics,
        fast_fitting=fit_arguments.fast_fitting,
        native_fitting=fit_arguments.native_fitting,
        native_legacy_fitting=fit_arguments.native_legacy_fitting,
    )
    t2 = time.perf_counter()
    fit_task.progress.seconds_process += t2 - t1
    return fit_task, fit_result


def _block_beam_energy(energy: Optional[numpy.ndarray]) -> Any:
    """Reduce the primary beam energy read for a block to a single value.

    :param energy: As returned by `_read_data`: `None`, a scalar or an array.
    :returns: `None` when there is no energy (`None` or an empty array),
        otherwise a single energy value.
    :raises NotImplementedError: When an array with more than one dimension holds
        more than one distinct energy.
    """
    if energy is None or energy.ndim == 0:
        return energy
    if energy.size == 0:
        # An empty energy dataset carries no energy (indexing it would fail).
        return None
    if energy.ndim == 1:
        # NOTE(review): only the first energy of the block is used, even when the
        # energies within the block differ.
        return energy[0]
    # `numpy.unique` returns distinct values (not indices).
    unique_energies = numpy.unique(energy)
    if unique_energies.size != 1:
        raise NotImplementedError(
            f"Block has more than one primary beam energy: {unique_energies}"
        )
    return unique_energies[0]
def _modify_configuration(
    configuration: PyMcaXrfConfiguration,
    energy: float,
    energy_multiplier: Optional[float],
) -> PyMcaXrfConfiguration:
    """Return a deep copy of `configuration` with its primary beam energy set.

    The input configuration is left unchanged; `energy_multiplier` is forwarded
    to `set_beam_energy` as is.
    """
    modified = configuration.model_copy(deep=True)
    set_beam_energy(modified, energy, energy_multiplier)
    return modified
def _iter_scan_slices(
    spectra_uri: str, block_size: int
) -> Generator[Tuple[int, int], None, None]:
    """Yield `(start, stop)` bounds that cover the first axis of the spectra dataset.

    Every slice has `block_size` spectra except the last one, which may be shorter.
    The file is closed before the first slice is yielded.
    """
    # TODO: assumes static HDF5 file for now
    dataset_url = as_dataurl(spectra_uri)
    with h5py_utils.File(dataset_url.file_path()) as h5file:
        total = len(h5file[dataset_url.data_path()])
    yield from (
        (first, min(first + block_size, total))
        for first in range(0, total, block_size)
    )
def _get_mca_num_channels(xrf_spectra_uri: str) -> int:
    """Return the size of the last axis (the channels) of the spectra dataset."""
    url = as_dataurl(xrf_spectra_uri)
    with h5py_utils.File(url.file_path()) as h5file:
        num_channels = h5file[url.data_path()].shape[-1]
    return num_channels
def _read_data(
    xrf_spectra_uri: str, beam_energy_uri: Optional[str], start: int, stop: int
) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
    """Read a block of spectra and the matching primary beam energy.

    :param xrf_spectra_uri: URI of the spectra dataset.
    :param beam_energy_uri: URI of the beam energy dataset, or `None`/empty when
        there is no energy to read.
    :param start: First spectrum index of the block.
    :param stop: End (exclusive) spectrum index of the block.
    :returns: `(spectra, energy)`. `spectra` is the slice `[start:stop]` of the
        spectra dataset (expected shape `(stop - start, num_channels)`).
        `energy` is `None` without `beam_energy_uri`, the whole dataset when it
        holds at most one value, and the slice `[start:stop]` otherwise.
    """
    spectra_url = as_dataurl(xrf_spectra_uri)
    with h5py_utils.File(spectra_url.file_path()) as h5f_spectra:
        spectra = h5f_spectra[spectra_url.data_path()][start:stop]
    if not beam_energy_uri:
        return spectra, None

    # Opened after the spectra file is closed: both datasets may live in the same file.
    energy_url = as_dataurl(beam_energy_uri)
    with h5py_utils.File(energy_url.file_path()) as h5f_energy:
        dset_energy = h5f_energy[energy_url.data_path()]
        if dset_energy.size in (0, 1):
            # A single (or no) energy applies to every spectrum.
            energy = dset_energy[()]
        else:
            energy = dset_energy[start:stop]
    return spectra, energy