from typing import Dict
import h5py
import numpy
import pytest
from ...io.hdf5 import split_h5uri
from ...xrffit import batch
from ..utils import generate_data
[docs]
@pytest.mark.parametrize("nscans", [1, 2], ids=["1_scan", "2_scans"])
@pytest.mark.parametrize("ndetectors", [1, 2], ids=["single_det", "multi_det"])
@pytest.mark.parametrize(
    "npoints_per_scan", [1, 7, 200], ids=["1_spectrum", "7_spectra", "200_spectra"]
)
@pytest.mark.parametrize("fast_fitting", [True, False], ids=["fast", "slow"])
@pytest.mark.parametrize("linear", [True, False], ids=["linear", "nonlinear"])
@pytest.mark.parametrize("samefile", [True, False], ids=["same_file", "different_file"])
def test_fit_xrf_urls(
    tmp_path, nscans, ndetectors, npoints_per_scan, fast_fitting, linear, samefile
):
    """Run ``batch.fit_xrf_urls`` on generated data and validate what it wrote.

    The test generates spectra with ``generate_data``, groups the resulting
    scan URIs per detector, fits every detector with ``batch.fit_xrf_urls``
    into ``<tmp_path>/output.h5::/1.1`` and checks each detector's
    ``fit/<detector_name>/results`` group with ``_validate_results``.

    Two parameter combinations are skipped:

    * slow fitting with more than 10 points per scan (too slow);
    * separate files with a single scan (adds nothing to the same-file case).
    """
    if not fast_fitting and npoints_per_scan > 10:
        pytest.skip("too slow, no extra value in testing")
    if not samefile and nscans == 1:
        pytest.skip("no extra value in testing")

    diagnostics = True
    quantification = True
    energy = 7.5
    energy_multiplier = 10

    # Generate data
    xrf_spectra_uris, spectra, parameters, configuration = generate_data(
        npoints_per_scan,
        energy,
        tmp_path=tmp_path,
        samefile=samefile,
        nscans=nscans,
        ndetectors=ndetectors,
    )

    # Raw data URI's, grouped by detector name. Distinct names are used for the
    # outer loop variable and the per-detector list so that neither shadows
    # the other.
    detector_uris = {}
    for uri_group in xrf_spectra_uris:
        for xrf_spectra_uri in uri_group:
            file_path, data_path = xrf_spectra_uri.split("::")
            data_parts = data_path.split("/")
            # The scan URI keeps the first two path components of the data
            # path; the last component is used as the detector name.
            bliss_scan_uri = f"{file_path}::/{'/'.join(data_parts[:2])}"
            detector_name = data_parts[-1]
            detector_uris.setdefault(detector_name, []).append(bliss_scan_uri)

    output_file = tmp_path / "output.h5"
    output_root_uri = f"{output_file}::/1.1"
    fit_io_uris = [
        batch.FitUris(
            bliss_scan_uris=bliss_scan_uris,
            detector_name=detector_name,
            energy_name="energy",
            output_root_uri=output_root_uri,
            xrf_spectra_uri_template="measurement/{}",
            energy_uri_template="instrument/positioners_start/{}",
        )
        for detector_name, bliss_scan_uris in detector_uris.items()
    ]

    # Configuration
    configuration.fit.linearfitflag = int(linear)

    # Fit
    batch.fit_xrf_urls(
        fit_io_uris,
        configuration,
        diagnostics=diagnostics,
        quantification=quantification,
        energy_multiplier=energy_multiplier,
        fast_fitting=fast_fitting,
    )

    # Validate
    # NOTE(review): detector names are paired with ``spectra`` by position, in
    # the order the detectors were first seen above. This assumes ``spectra``
    # from ``generate_data`` is ordered the same way - TODO confirm.
    for detector_name, detector_spectra in zip(detector_uris, spectra):
        xrf_results_uri = f"{output_root_uri}/fit/{detector_name}/results"
        _validate_results(
            xrf_results_uri, fast_fitting, linear, parameters, detector_spectra
        )
def _validate_results(
    xrf_results_uri: str,
    fast_fitting: bool,
    linear: bool,
    parameters: Dict[str, numpy.ndarray],
    spectra: numpy.ndarray,
):
    """Check the content of one fit-result group in the output HDF5 file.

    :param xrf_results_uri: URI of the result group; it is split into a file
        name and an HDF5 path with ``split_h5uri``.
    :param fast_fitting: whether the fit was run with fast fitting. It selects
        the expected number of fitted parameters, whether a ``derivatives``
        group is expected and the sign the stored residuals are compared with.
    :param linear: accepted for symmetry with the test parameters; not used by
        this function.
    :param parameters: expected values per parameter dataset name, each checked
        with ``_check_param_dataset``.
    :param spectra: the spectra expected to be stored under ``fit/data``.
    :raises AssertionError: when any check fails.
    """
    filename, h5path = split_h5uri(xrf_results_uri)

    # Hard-coded counts; presumably they match the configuration produced by
    # ``generate_data`` - TODO confirm.
    n_peak_areas = 12
    n_nonlinear = 12
    n_fitted = n_peak_areas if fast_fitting else n_peak_areas + n_nonlinear

    top_level_names = {"parameters", "uncertainties", "mass_fractions", "fit"}
    if fast_fitting:
        top_level_names.add("derivatives")

    with h5py.File(filename, mode="r") as h5file:
        results = h5file[h5path]
        assert set(results) == top_level_names

        # Fit results
        for group_name in ("parameters", "uncertainties"):
            assert len(results[group_name]) == n_fitted
        for dset_name, expected_values in parameters.items():
            _check_param_dataset(expected_values, dset_name, results)
        if fast_fitting:
            assert len(results["derivatives"]) == n_peak_areas + 1

        # Fit
        assert set(results["fit"]) == {"data", "model", "residuals", "energy"}

        stored_spectra = results["fit/data"][()]
        numpy.testing.assert_allclose(spectra, stored_spectra, atol=1e-10)

        fitted_model = results["fit/model"][()]
        stored_residuals = results["fit/residuals"][()]

        # The stored residuals are compared with data-minus-model for fast
        # fitting and with the negated difference otherwise.
        difference = spectra - fitted_model
        reference = difference if fast_fitting else -difference

        # Only compare where the model is defined (not NaN).
        defined = ~numpy.isnan(fitted_model)
        numpy.testing.assert_allclose(
            stored_residuals[defined], reference[defined], atol=1e-4
        )
def _check_param_dataset(
    expected_counts,
    dset_name,
    result_group,
    *,
    nsigma=3,
    error_check_size_limit=10,
    expected_step=50,
    step_tolerance=5,
):
    """Check one fitted parameter dataset against its expected values.

    Two checks are made on ``result_group["parameters/<dset_name>"]``:

    1. Only when ``expected_counts`` has fewer than ``error_check_size_limit``
       elements: every fitted value must lie within ``nsigma`` times the
       value stored under ``uncertainties/<dset_name>`` of its expected value.
    2. Always: consecutive fitted values must differ by ``expected_step``,
       to within ``step_tolerance``.

    :param expected_counts: numpy array with the expected values.
    :param dset_name: name of the dataset under ``parameters`` and
        ``uncertainties``.
    :param result_group: mapping-like object indexed by ``"parameters/<name>"``
        and ``"uncertainties/<name>"`` whose items support ``[()]``
        (``_validate_results`` passes a group of an open HDF5 file).
    :param nsigma: number of uncertainties the fitted values may deviate.
    :param error_check_size_limit: the uncertainty check is skipped when
        ``expected_counts.size`` is not below this limit.
    :param expected_step: expected difference between consecutive fitted values.
    :param step_tolerance: allowed deviation from ``expected_step``.
    :raises AssertionError: when a check fails; the message names the dataset.
    """
    fit_counts = result_group[f"parameters/{dset_name}"][()]

    if expected_counts.size < error_check_size_limit:
        # TODO: does not always work. Weights are disabled but even when they
        # are enabled, it does not work.
        fit_errors = nsigma * result_group[f"uncertainties/{dset_name}"][()]
        deviation = numpy.abs(fit_counts - expected_counts)
        assert (deviation < fit_errors).all(), (
            f"'{dset_name}': fitted values deviate from the expected values "
            f"by more than {nsigma} times the fit uncertainty"
        )

    # With fewer than two values ``numpy.diff`` is empty and this check passes.
    step_deviation = numpy.abs(numpy.diff(fit_counts) - expected_step)
    assert (step_deviation < step_tolerance).all(), (
        f"'{dset_name}': consecutive fitted values do not differ by "
        f"{expected_step} within a tolerance of {step_tolerance}"
    )