import logging
import warnings
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from ewokscore import Task
from ewokscore.model import BaseInputModel
from ewokscore.model import BaseOutputModel
from pydantic import Field
from ...io import hdf5
from ...io import output_uri
from ...xrffit.pymca_config import PyMcaXrfConfiguration
from ..positioner_utils import get_primary_beam_energy_value
from .execute import fit_multi
from .execute import fit_single
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
[docs]
class FitSingleScanSingleDetectorOutputs(BaseOutputModel):
    """Outputs of the single-scan, single-detector XRF fit task.

    Besides the URI of the fit results, the scan, detector and output-root
    inputs are echoed back so that downstream tasks can chain on them.
    """

    # Fitted scan (copied from the inputs).
    bliss_scan_uri: str = Field(
        examples=["/data/dataset.h5::/1.1"],
        description="Bliss scan URI.",
    )
    # Output root exactly as it was received, not the composed fit URI.
    output_root_uri: str = Field(
        examples=["/results/dataset.h5", "/results/dataset.h5::/1.1"],
        description="Original output root URI received as input.",
    )
    output_root_group: Optional[str] = Field(
        description="Original output root group received as input.",
        default=None,
    )
    # Location of the NXprocess group written by the fit.
    xrf_results_uri: str = Field(
        examples=["/results/dataset.h5::/1.1/fit/mca0/results"],
        description="Output NXprocess results URI.",
    )
    detector_name: str = Field(description="Name of the fitted detector.")
[docs]
class FitSingleScanSingleDetector(
    Task,
    input_model=FitSingleScanSingleDetectorInputs,
    output_model=FitSingleScanSingleDetectorOutputs,
):
    """XRF fit of one scan with one detector.

    The task inputs are translated into keyword arguments for ``fit_single``:
    ``config`` becomes ``pymca_configuration``, the primary beam energy is
    resolved from the scan and the output root URI is expanded into the full
    fit output URI. The identifying inputs are forwarded to the outputs next
    to the returned results URI.
    """

    def run(self):
        results_uri = fit_single(**self._fit_parameters())
        src, dst = self.inputs, self.outputs
        # Echo the identifying inputs so downstream tasks can chain on them.
        dst.bliss_scan_uri = src.bliss_scan_uri
        dst.detector_name = src.detector_name
        dst.output_root_uri = src.output_root_uri
        dst.output_root_group = src.output_root_group
        dst.xrf_results_uri = results_uri

    def _fit_parameters(self) -> dict:
        """Return the keyword arguments passed to ``fit_single``."""
        kwargs = self.get_input_values()
        # ``fit_single`` names the configuration ``pymca_configuration``.
        config = kwargs.pop("config")
        kwargs["pymca_configuration"] = config
        _resolve_energy(kwargs)
        _resolve_output_uri(kwargs)
        return kwargs
[docs]
class FitSingleScanMultiDetectorOutputs(BaseOutputModel):
    """Outputs of the single-scan, multi-detector XRF fit task.

    Besides the URIs of the fit results, the scan, detector and output-root
    inputs are echoed back so that downstream tasks can chain on them.
    """

    bliss_scan_uri: str = Field(
        description="Bliss scan URI.", examples=["/data/dataset.h5::/1.1"]
    )
    output_root_uri: str = Field(
        description="Original output root URI received as input.",
        examples=[
            "/results/dataset.h5",
            "/results/dataset.h5::/1.1",
        ],
    )
    output_root_group: Optional[str] = Field(
        default=None, description="Original output root group received as input."
    )
    # ``examples`` is a list of example *values*. This field is a list, so
    # each example is itself a list of URIs.
    xrf_results_uris: List[str] = Field(
        description="Output NXprocess results URIs.",
        examples=[
            [
                "/results/dataset.h5::/1.1/fit/fx_nano_det0/results",
                "/results/dataset.h5::/1.1/fit/fx_nano_det1/results",
            ]
        ],
    )
    detector_names: List[str] = Field(description="Names of the fitted detectors.")
[docs]
class FitSingleScanMultiDetector(
    Task,
    input_model=FitSingleScanMultiDetectorInputs,
    output_model=FitSingleScanMultiDetectorOutputs,
):
    """XRF fit of one scan with multiple detectors.

    The single scan is wrapped in a one-element ``bliss_scan_uris`` list so
    that the ``fit_multi`` backend can be used.
    """

    def run(self):
        results_uris = fit_multi(**self._fit_parameters())
        src, dst = self.inputs, self.outputs
        # Echo the identifying inputs so downstream tasks can chain on them.
        dst.bliss_scan_uri = src.bliss_scan_uri
        dst.detector_names = src.detector_names
        dst.output_root_uri = src.output_root_uri
        dst.output_root_group = src.output_root_group
        dst.xrf_results_uris = results_uris

    def _fit_parameters(self) -> dict:
        """Return the keyword arguments passed to ``fit_multi``."""
        kwargs = self.get_input_values()
        configs = kwargs.pop("configs")
        scan_uri = kwargs.pop("bliss_scan_uri")
        kwargs["pymca_configurations"] = configs
        # ``fit_multi`` takes a list of scans; here there is exactly one.
        kwargs["bliss_scan_uris"] = [scan_uri]
        _resolve_energies(kwargs)
        _resolve_output_uri(kwargs)
        return kwargs
[docs]
class FitStackSingleDetectorOutputs(BaseOutputModel):
    """Outputs of the stack (multi-scan), single-detector XRF fit task.

    Besides the URI of the fit results, the scan, detector and output-root
    inputs are echoed back so that downstream tasks can chain on them.
    """

    # Fitted scans (copied from the inputs). Each example is a full list.
    bliss_scan_uris: List[str] = Field(
        examples=[["/data/dataset.h5::/1.1", "/data/dataset.h5::/2.1"]],
        description="Bliss scan URI's.",
    )
    # Output root exactly as it was received, not the composed fit URI.
    output_root_uri: str = Field(
        examples=["/results/dataset.h5", "/results/dataset.h5::/1.1"],
        description="Original output root URI received as input.",
    )
    output_root_group: Optional[str] = Field(
        description="Original output root group received as input.",
        default=None,
    )
    xrf_results_uri: str = Field(
        examples=["/results/dataset.h5::/1.1/fit/fx_nano_det0/results"],
        description="Output NXprocess results URI.",
    )
    detector_name: str = Field(description="Name of the fitted detector.")
[docs]
class FitStackSingleDetector(
    Task,
    input_model=FitStackSingleDetectorInputs,
    output_model=FitStackSingleDetectorOutputs,
):
    """XRF fit of a stack of identical scans with one detector.

    The single detector and its configuration are wrapped in one-element
    lists for ``fit_multi``. The first returned results URI is exposed as
    ``xrf_results_uri``.
    """

    def run(self):
        results_uris = fit_multi(**self._fit_parameters())
        src, dst = self.inputs, self.outputs
        # Echo the identifying inputs so downstream tasks can chain on them.
        dst.bliss_scan_uris = src.bliss_scan_uris
        dst.detector_name = src.detector_name
        dst.output_root_uri = src.output_root_uri
        dst.output_root_group = src.output_root_group
        dst.xrf_results_uri = results_uris[0]

    def _fit_parameters(self) -> dict:
        """Return the keyword arguments passed to ``fit_multi``."""
        kwargs = self.get_input_values()
        name = kwargs.pop("detector_name")
        kwargs["detector_names"] = [name]
        config = kwargs.pop("config")
        kwargs["pymca_configurations"] = [config]
        _resolve_energies(kwargs)
        _resolve_output_uri(kwargs)
        return kwargs
[docs]
class FitStackMultiDetectorOutputs(BaseOutputModel):
    """Outputs of the stack (multi-scan), multi-detector XRF fit task.

    Besides the URIs of the fit results, the scan, detector and output-root
    inputs are echoed back so that downstream tasks can chain on them.
    """

    bliss_scan_uris: List[str] = Field(
        description="Bliss scan URI's.",
        examples=[["/data/dataset.h5::/1.1", "/data/dataset.h5::/2.1"]],
    )
    output_root_uri: str = Field(
        description="Original output root URI received as input.",
        examples=[
            "/results/dataset.h5",
            "/results/dataset.h5::/1.1",
        ],
    )
    output_root_group: Optional[str] = Field(
        default=None, description="Original output root group received as input."
    )
    # ``examples`` is a list of example *values*. This field is a list, so
    # each example is itself a list of URIs (same shape as ``bliss_scan_uris``).
    xrf_results_uris: List[str] = Field(
        description="Output NXprocess results URIs.",
        examples=[
            [
                "/results/dataset.h5::/1.1/fit/fx_nano_det0/results",
                "/results/dataset.h5::/1.1/fit/fx_nano_det1/results",
            ]
        ],
    )
    detector_names: List[str] = Field(description="Names of the fitted detectors.")
[docs]
class FitStackMultiDetector(
    Task,
    input_model=FitStackMultiDetectorInputs,
    output_model=FitStackMultiDetectorOutputs,
):
    """XRF fit of a stack of identical scans with multiple detectors.

    Inputs are forwarded to ``fit_multi`` (``configs`` renamed to
    ``pymca_configurations``) after resolving per-scan energies and the full
    fit output URI.
    """

    def run(self):
        results_uris = fit_multi(**self._fit_parameters())
        src, dst = self.inputs, self.outputs
        # Echo the identifying inputs so downstream tasks can chain on them.
        dst.bliss_scan_uris = src.bliss_scan_uris
        dst.detector_names = src.detector_names
        dst.output_root_uri = src.output_root_uri
        dst.output_root_group = src.output_root_group
        dst.xrf_results_uris = results_uris

    def _fit_parameters(self) -> dict:
        """Return the keyword arguments passed to ``fit_multi``."""
        kwargs = self.get_input_values()
        configs = kwargs.pop("configs")
        kwargs["pymca_configurations"] = configs
        _resolve_energies(kwargs)
        _resolve_output_uri(kwargs)
        return kwargs
def _resolve_energy(params: dict) -> None:
    """Replace the energy lookup inputs of *params* by a resolved ``energy``.

    ``energy_name`` and ``energy_uri_template`` are removed from *params*
    (both optional, defaulting to ``None``). They are forwarded to
    ``get_primary_beam_energy_value`` for the scan
    ``params["bliss_scan_uri"]``, and the result is stored as
    ``params["energy"]``.

    :param params: fit keyword arguments, modified in place.
    """
    # Dict literal evaluation order keeps the pops in the original order.
    lookup = {
        "energy_name": params.pop("energy_name", None),
        "energy_uri_template": params.pop("energy_uri_template", None),
    }
    scan_uri = params["bliss_scan_uri"]
    params["energy"] = get_primary_beam_energy_value(
        scan_uri, search_on_units=False, **lookup
    )
def _resolve_energies(params: dict) -> None:
    """Replace the energy lookup inputs of *params* by per-scan ``energies``.

    ``energy_name`` and ``energy_uri_template`` are removed from *params*
    (both optional, defaulting to ``None``). The same lookup is applied to
    every URI in ``params["bliss_scan_uris"]``, and the resulting list (in
    scan order) is stored as ``params["energies"]``.

    :param params: fit keyword arguments, modified in place.
    """
    lookup = {
        "energy_name": params.pop("energy_name", None),
        "energy_uri_template": params.pop("energy_uri_template", None),
    }
    energies = []
    for scan_uri in params["bliss_scan_uris"]:
        energies.append(
            get_primary_beam_energy_value(scan_uri, search_on_units=False, **lookup)
        )
    params["energies"] = energies
def _resolve_output_uri(params: dict) -> None:
    """Replace ``params["output_root_uri"]`` by the full fit output URI.

    ``output_root_group`` and the deprecated ``process_uri_template`` are
    removed from *params*. Both are optional and default to ``None``. The new
    URI is built by ``output_uri.compose_full_output_uri`` from three parts:

    - the original ``output_root_uri``;
    - the HDF5 path of the scan (``bliss_scan_uri``, or the first entry of
      ``bliss_scan_uris``) as default data path;
    - ``(output_root_group, output_uri.DEFAULT_FIT_NAME)`` as extra data paths.

    A non-empty ``process_uri_template`` is ignored. It triggers a
    ``DeprecationWarning`` and a log message.

    :param params: fit keyword arguments, modified in place.
    """
    output_root_group = params.pop("output_root_group", None)
    # Deprecated input: tolerate its absence like the other optional inputs
    # instead of raising KeyError.
    process_uri_template = params.pop("process_uri_template", None)

    if "bliss_scan_uri" in params:
        scan_uri = params["bliss_scan_uri"]
    else:
        # Stacks share one output location, derived from the first scan.
        scan_uri = params.get("bliss_scan_uris")[0]
    _, scan_h5path = hdf5.split_h5uri(scan_uri)

    output_root_uri = output_uri.compose_full_output_uri(
        params.get("output_root_uri"),
        default_output_data_path=scan_h5path,
        extra_data_paths=(output_root_group, output_uri.DEFAULT_FIT_NAME),
    )

    if process_uri_template:
        warnings.warn(
            "'process_uri_template' is deprecated and will be removed in a future version. ",
            DeprecationWarning,
            stacklevel=2,
        )
        logger.warning(
            "Ignore process_uri_template=%r and save fit results in %r",
            process_uri_template,
            output_root_uri,
        )

    params["output_root_uri"] = output_root_uri