Source code for ewoksfluo.tests.fit.legacy.test_batch_fit
from typing import Dict
import h5py
import numpy
import pytest
from ....io.hdf5 import split_h5uri
from ....xrffit.legacy import outputbuffer_context
from ....xrffit.legacy import perform_batch_fit
from ....xrffit.pymca_config import pymca_configdict_from_model
from ....xrffit.pymca_config import pymca_configdict_to_file
from ...utils import generate
[docs]
@pytest.mark.parametrize("nscans", [1, 2], ids=["1_scan", "2_scans"])
@pytest.mark.parametrize(
"npoints_per_scan", [1, 7, 200], ids=["1_spectrum", "7_spectra", "200_spectra"]
)
@pytest.mark.parametrize("fast_fitting", [True, False], ids=["fast", "slow"])
@pytest.mark.parametrize("linear", [True, False], ids=["linear", "nonlinear"])
@pytest.mark.parametrize("output_handler", ["nexus", "pymca"])
@pytest.mark.parametrize("samefile", [True, False], ids=["same_file", "different_file"])
def test_perform_batch_fit(
tmp_path, nscans, npoints_per_scan, fast_fitting, linear, output_handler, samefile
):
if not fast_fitting and npoints_per_scan > 10:
pytest.skip("too slow, no extra value in testing")
if not samefile and nscans == 1:
pytest.skip("no extra value in testing")
diagnostics = True
figuresofmerit = True
quantification = True
energy = 7.5
energy_multiplier = 10
# Generate data
xrf_spectra_uris, spectra, parameters, config = generate.generate_data(
npoints_per_scan,
energy,
tmp_path=tmp_path,
samefile=samefile,
nscans=nscans,
ndetectors=1,
)
# Single detector:
xrf_spectra_uris = xrf_spectra_uris[0]
spectra = spectra[0]
# Output
output_root_uri = str(tmp_path / "output.h5::/1.1/fit")
xrf_results_uri = str(tmp_path / "output.h5::/1.1/fit/results")
# Configuration
config.fit.linearfitflag = int(linear)
config_filename = str(tmp_path / "config.cfg")
pymca_configdict_to_file(pymca_configdict_from_model(config), config_filename)
# Perform fit
with outputbuffer_context(
output_root_uri,
diagnostics=diagnostics,
figuresofmerit=figuresofmerit,
output_handler=output_handler,
) as output_buffer:
kwargs = {
"xrf_spectra_uris": xrf_spectra_uris,
"pymca_configuration": config_filename,
"output_buffer": output_buffer,
"energy": energy,
"energy_multiplier": energy_multiplier,
"fast_fitting": fast_fitting,
"quantification": quantification,
}
if not samefile and nscans > 1:
with pytest.raises(
ValueError, match="cannot handle scans in different files"
):
perform_batch_fit(**kwargs)
return
perform_batch_fit(**kwargs)
assert output_buffer.xrf_results_uri == xrf_results_uri
if output_handler == "pymca":
# The pymca output handler keeps the data dimensions
if nscans > 1 and npoints_per_scan > 1:
newshape = (nscans, npoints_per_scan, spectra.shape[-1])
spectra = spectra.reshape(newshape)
else:
spectra = spectra[None, ...]
_validate_results(
xrf_results_uri, output_handler, fast_fitting, linear, parameters, spectra
)
def _validate_results(
xrf_results_uri: str,
output_handler: str,
fast_fitting: bool,
linear: bool,
parameters: Dict[str, numpy.ndarray],
spectra: numpy.ndarray,
):
output_file, output_h5path = split_h5uri(xrf_results_uri)
# Validate results
with h5py.File(output_file, mode="r") as h5file:
result_group = h5file[output_h5path]
nparams = 12
nobservations = 1021
# Fit results
if output_handler == "pymca":
# includes *_error softlinks
assert len(result_group["parameters"]) == 2 * nparams
else:
assert len(result_group["parameters"]) == nparams
assert len(result_group["uncertainties"]) == nparams
for name, values in parameters.items():
_check_param_dataset(values, name, result_group)
# Diagnostics
if fast_fitting:
assert set(result_group["diagnostics"]) == {
"nFreeParameters",
"nObservations",
}
else:
assert set(result_group["diagnostics"]) == {
"chisq",
"nFreeParameters",
"nObservations",
}
numpy.testing.assert_array_equal(
result_group["diagnostics/nFreeParameters"][()], nparams
)
numpy.testing.assert_array_equal(
result_group["diagnostics/nObservations"][()], nobservations
)
if fast_fitting:
if output_handler == "pymca":
# + channels and energy
assert len(result_group["derivatives"]) == nparams + 2
else:
# + energy
assert len(result_group["derivatives"]) == nparams + 1
# Fit
if output_handler == "pymca":
if not fast_fitting and not linear:
expected = {
"data",
"model",
"residuals",
"dim2",
"dim0",
"dim1",
}
else:
expected = {
"data",
"model",
"residuals",
"energy",
"channels",
"dim0",
"dim1",
}
else:
if not fast_fitting and not linear:
expected = {"data", "model", "residuals", "dim2"}
else:
expected = {"data", "model", "residuals", "energy"}
assert set(result_group["fit"]) == expected
spectra2 = result_group["fit/data"][()]
numpy.testing.assert_allclose(spectra, spectra2, atol=1e-10)
model = result_group["fit/model"][()]
residuals = result_group["fit/residuals"][()]
residuals2 = spectra - model
mask = ~numpy.isnan(model)
if not fast_fitting:
residuals2 = -residuals2
numpy.testing.assert_allclose(residuals[mask], residuals2[mask], atol=1e-4)
def _check_param_dataset(expected_counts, dset_name, result_group):
fit_counts = result_group[f"parameters/{dset_name}"][()]
if expected_counts.size < 10:
# TODO: does not always work. Weights are disabled but even when they are enabled, it does not work.
fit_errors = 3 * result_group[f"uncertainties/{dset_name}"][()]
diff = numpy.abs(fit_counts - expected_counts)
assert (diff < fit_errors).all()
diff = numpy.abs(numpy.diff(fit_counts) - 50)
assert (diff < 5).all()