import os
import logging
from dataclasses import dataclass
from typing import Dict, Union, Tuple, List, Any, Optional
import h5py
import numpy
from pint import Quantity
from numpy.typing import DTypeLike
from .hdf5 import join_h5url
from .hdf5 import split_h5uri
from ..units import unit_registry
from ..units import units_as_str
from ..math.expression import expression_variables
from ..math.expression import eval_expression
logger = logging.getLogger(__name__)
def concatenate_bliss_scans(
bliss_scan_uris: List[str],
output_root_uri: str,
virtual_axes: Optional[Dict[str, str]] = None,
axes_units: Optional[Dict[str, str]] = None,
start_var: str = "<",
end_var: str = ">",
) -> str:
"""Concatenate Bliss scans in a virtual scan that looks exactly like a Bliss scan.
This method cannot handle scans that are interrupted, except when it is the last scan to be concatenated.
    :param bliss_scan_uris: scans to concatenate
    :param output_root_uri: URI of the HDF5 group in which to save the virtual scan (the data path defaults to `/1.1`)
:param virtual_axes: virtual motors. For example `{"sy": "<samy>+<sampy>/1000", "sz": "<samz>+<sampz>/1000"}`
:param axes_units: axis units. For example `{"samy": "mm", "sampy": "um"}`
:param start_var: marks the start of a variable name in `virtual_axes` expressions
:param end_var: marks the end of a variable name in `virtual_axes` expressions
:returns: URI to the virtual concatenated scan
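
    For example (hypothetical file names, assuming the `filename::/data/path`
    URI convention handled by `split_h5uri`):

    .. code-block:: python

        uri = concatenate_bliss_scans(
            ["/data/sample_0001.h5::/1.1", "/data/sample_0002.h5::/1.1"],
            "/data/sample_concat.h5::/1.1",
            virtual_axes={"sy": "<samy>+<sampy>/1000"},
            axes_units={"samy": "mm", "sampy": "um"},
        )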
"""
if axes_units is None:
axes_units = dict()
out_filename, out_path = split_h5uri(output_root_uri)
if not out_path:
out_path = "/1.1"
output_root_uri = join_h5url(out_filename, out_path)
# Return when out_path already exists
with h5py.File(out_filename, mode="a") as out_root:
if out_path in out_root:
logger.warning("%s already exists", output_root_uri)
return output_root_uri
# Compile data to concatenate
in_scan = None
datasets = None
for bliss_scan_uri in bliss_scan_uris:
logger.debug("Parse %s for concatenation", bliss_scan_uri)
in_filename, in_path = split_h5uri(bliss_scan_uri)
with h5py.File(in_filename, mode="r", locking=False) as in_root:
in_scani = in_root[in_path]
if in_scan is None:
in_scan, datasets = _parse_hdf5_group(in_scani)
else:
_append_H5Datasets(in_scani, datasets)
logger.debug("Save concatenation in %s", output_root_uri)
scalar_repeats = _trim_datasets(datasets)
with h5py.File(out_filename, mode="a") as out_root:
out_scan = out_root.create_group(out_path)
_save_hdf5_group(out_scan, in_scan, scalar_repeats)
if virtual_axes:
virtual_axis_datasets = _add_virtual_axes(
in_scan, virtual_axes, axes_units, start_var=start_var, end_var=end_var
)
with h5py.File(out_filename, mode="r") as out_root:
out_scan = out_root[out_path]
for item in virtual_axis_datasets:
_resolve_virtual_axis(item, out_scan, axes_units)
with h5py.File(out_filename, mode="a") as out_root:
out_scan = out_root[out_path]
_save_hdf5_group(out_scan, in_scan, scalar_repeats, skip_existing=True)
return output_root_uri
@dataclass
class _H5SoftLink:
path: str
relative: bool
@dataclass
class _H5VirtualSource:
filename: str
data_path: str
    shape: Tuple[int, ...]
dtype: DTypeLike
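    # When set, only the first `force_shape0` rows of the source are mapped
    # into the virtual layout (set by `_trim_datasets`).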
force_shape0: Optional[int] = None
@property
    def layout_shape(self) -> Tuple[int, ...]:
if self.force_shape0 is None:
return self.shape
return (self.force_shape0,) + self.shape[1:]
def get_layout_h5item(self, filename: str) -> h5py.VirtualSource:
ext_filename = os.path.relpath(self.filename, os.path.dirname(filename))
return h5py.VirtualSource(
ext_filename, self.data_path, shape=self.shape, dtype=self.dtype
)
@dataclass
class _H5Item:
attrs: dict
@dataclass
class _H5Group(_H5Item):
items: Dict[str, Union[_H5Item, _H5SoftLink]]
@dataclass
class _H5Dataset(_H5Item):
value: Any
@dataclass
class _H5DatasetExpression(_H5Item):
expression: str
units: str
start_var: str = "<"
end_var: str = ">"
result: Optional[Quantity] = None
@dataclass
class _H5ScalarDataset(_H5Item):
values: List[numpy.number]
@dataclass
class _VirtualDataset(_H5Item):
sources: List[_H5VirtualSource]
@property
def dtype(self) -> DTypeLike:
return self.sources[0].dtype
@property
    def shape(self) -> Tuple[int, ...]:
s0 = sum(source_scani.layout_shape[0] for source_scani in self.sources)
sother = self.sources[0].layout_shape[1:]
return (s0,) + sother
    def __len__(self) -> int:
        return self.shape[0]
def _parse_hdf5_group(
parent: h5py.Group, _name_prefix_strip: int = 0
) -> Tuple[_H5Group, Dict[str, Union[_VirtualDataset, _H5ScalarDataset]]]:
if _name_prefix_strip <= 0:
_name_prefix_strip = len(parent.name) + 1
filename = parent.file.filename
items = dict()
group = _H5Group(attrs=dict(parent.attrs), items=items)
datasets = dict()
for name in parent:
link = parent.get(name, getlink=True)
if isinstance(link, h5py.HardLink):
child = parent.get(name)
if isinstance(child, h5py.Dataset):
if numpy.issubdtype(child.dtype, numpy.number):
if child.ndim == 0:
values = [child[()]]
dataset = _H5ScalarDataset(
attrs=dict(child.attrs), values=values
)
else:
sources = [
_H5VirtualSource(
filename=filename,
data_path=child.name,
shape=child.shape,
dtype=child.dtype,
)
]
dataset = _VirtualDataset(
attrs=dict(child.attrs), sources=sources
)
key = child.name[_name_prefix_strip:]
datasets[key] = dataset
items[name] = dataset
else:
value = child[()]
if isinstance(value, bytes):
value = value.decode()
items[name] = _H5Dataset(attrs=dict(child.attrs), value=value)
elif isinstance(child, h5py.Group):
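                # Skip NXdata groups: in Bliss files they typically only link
                # to datasets that are stored elsewhere in the scan.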
if child.attrs.get("NX_class", "") == "NXdata":
continue
items[name], sub_H5Datasets = _parse_hdf5_group(
child, _name_prefix_strip=_name_prefix_strip
)
datasets.update(sub_H5Datasets)
else:
                logger.warning(
                    "ignore HDF5 item %s/%s for concatenation (%s)",
                    parent.name,
                    name,
                    type(child),
                )
elif isinstance(link, h5py.SoftLink):
if link.path.startswith("/"):
items[name] = _H5SoftLink(
path=link.path[_name_prefix_strip:], relative=False
)
else:
items[name] = _H5SoftLink(path=link.path, relative=True)
else:
            logger.warning(
                "ignore HDF5 link %s/%s for concatenation (%s)",
                parent.name,
                name,
                type(link),
            )
return group, datasets
def _append_H5Datasets(
group: h5py.Group, datasets: Dict[str, Union[_VirtualDataset, _H5ScalarDataset]]
) -> None:
filename = group.file.filename
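    # For each dataset collected from the first scan, record the matching
    # dataset of this scan: a scalar value or an extra virtual source.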
for name, dataset in datasets.items():
child = group[name]
if isinstance(dataset, _H5ScalarDataset):
dataset.values.append(child[()])
else:
dataset.sources.append(
_H5VirtualSource(
filename=filename,
data_path=child.name,
shape=child.shape,
dtype=child.dtype,
)
)
def _trim_datasets(
datasets: Dict[str, Union[_VirtualDataset, _H5ScalarDataset]],
) -> List[int]:
dataset_names, dataset_lengths = _dataset_lengths_per_scan(datasets)
min_dataset_lengths = list()
for scan_index, dataset_lengths_scani in enumerate(dataset_lengths):
min_dataset_length_scani = min(dataset_lengths_scani)
min_dataset_lengths.append(min_dataset_length_scani)
for dataset_name in dataset_names:
vsource_scani = datasets[dataset_name].sources[scan_index]
vsource_scani.force_shape0 = min_dataset_length_scani
    # Even if you know the logic behind a dataset length, for example
    #
    #     dataset_length = (nfast + 1) * nslow
    #
    # there is no general way to trim it, so just trim all datasets of a scan
    # to the same (minimum) number of points.
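    # For example, if a scan has datasets with lengths 100, 100 and 99
    # (e.g. because the scan was interrupted), all three are trimmed to 99.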
return min_dataset_lengths
def _dataset_lengths_per_scan(
datasets: Dict[str, Union[_VirtualDataset, _H5ScalarDataset]],
) -> Tuple[List[str], List[List[int]]]:
"""
:returns: dataset names with shape `(ndatasets,)` and dataset lengths with shape `(nscans, ndatasets)`
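
    For example, with two scans that each contain datasets `a` and `b`
    (illustrative names; `na0` is the length of `a` in the first scan)::

        dataset_names = ["a", "b"]
        dataset_lengths = [[na0, nb0], [na1, nb1]]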
"""
dataset_lengths = list()
dataset_names = list()
for name, dataset in datasets.items():
if not isinstance(dataset, _VirtualDataset):
continue
dataset_names.append(name)
if dataset_lengths:
for source_scani, dataset_lengths_scani in zip(
dataset.sources, dataset_lengths
):
dataset_lengths_scani.append(source_scani.shape[0])
else:
dataset_lengths = [
[source_scani.shape[0]] for source_scani in dataset.sources
]
return dataset_names, dataset_lengths
def _save_hdf5_group(
group: h5py.Group,
structure: _H5Group,
scalar_repeats: List[int],
skip_existing: bool = False,
_name_prefix: str = "",
) -> None:
filename = group.file.filename
group.attrs.update(structure.attrs)
if not _name_prefix:
_name_prefix = group.name
for name, item in structure.items.items():
skip_item = skip_existing and name in group
if isinstance(item, _H5Group):
if skip_item:
subgroup = group[name]
else:
subgroup = group.create_group(name)
_save_hdf5_group(
subgroup,
item,
scalar_repeats,
skip_existing=skip_existing,
_name_prefix=_name_prefix,
)
continue
if skip_item:
continue
if isinstance(item, _H5SoftLink):
if item.relative:
group[name] = h5py.SoftLink(item.path)
else:
group[name] = h5py.SoftLink(f"{_name_prefix}/{item.path}")
elif isinstance(item, _H5Dataset):
group[name] = item.value
group[name].attrs.update(item.attrs)
elif isinstance(item, _H5ScalarDataset):
if len(set(item.values)) > 1:
# No longer a scalar dataset
group[name] = numpy.repeat(item.values, scalar_repeats)
else:
group[name] = item.values[0]
group[name].attrs.update(item.attrs)
elif isinstance(item, _VirtualDataset):
layout = h5py.VirtualLayout(shape=item.shape, dtype=item.dtype)
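            # Stack the sources along the first axis: map the first `n`
            # (possibly trimmed) rows of each source onto consecutive rows
            # of the layout.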
start_index = 0
for source_scani in item.sources:
n = source_scani.layout_shape[0]
vsource = source_scani.get_layout_h5item(filename)
layout[start_index : start_index + n] = vsource[:n]
start_index += n
group.create_virtual_dataset(name, layout, fillvalue=numpy.nan)
group[name].attrs.update(item.attrs)
elif isinstance(item, _H5DatasetExpression):
if item.result is None:
continue
group[name] = item.result.magnitude
if not item.result.units.dimensionless:
group[name].attrs["units"] = units_as_str(item.result.units)
else:
logger.debug(f"ignore HDF5 item {name} for saving ({type(item)})")
def _ensure_same_shape0(
variables: Dict[str, numpy.ndarray], name_map: Dict[str, str]
) -> None:
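    # Expression variables may come from datasets of different lengths;
    # trim every array variable to the shortest first-axis length so the
    # expression can be evaluated element-wise.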
npoints = dict()
for key, values in variables.items():
if values.ndim >= 1:
npoints[key] = len(values)
if len(set(npoints.values())) <= 1:
return
npoints_min = min(npoints.values())
    for key, npoints_orig in npoints.items():
        if npoints_orig == npoints_min:
            continue
        logger.warning(
            "trim '%s' from %d points to %d", name_map[key], npoints_orig, npoints_min
        )
        variables[key] = variables[key][:npoints_min]
def _add_virtual_axes(
group: _H5Group,
virtual_axes: Dict[str, str],
axes_units: Dict[str, str],
start_var: str = "<",
end_var: str = ">",
) -> List[_H5DatasetExpression]:
instrument = group.items["instrument"]
measurement = group.items["measurement"]
positioners = instrument.items["positioners"]
datasets = list()
for motor_name, expression in virtual_axes.items():
motor_value = _H5DatasetExpression(
attrs=dict(),
expression=expression,
units=axes_units.get(motor_name, ""),
start_var=start_var,
end_var=end_var,
)
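        # Expose the virtual motor like a real Bliss motor: an NXpositioner
        # group under instrument, mirrored in measurement and positioners
        # through soft links.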
instrument.items[motor_name] = _H5Group(
attrs={"NX_class": "NXpositioner"}, items={"value": motor_value}
)
measurement.items[motor_name] = _H5SoftLink(
f"instrument/{motor_name}/value", relative=False
)
positioners.items[motor_name] = _H5SoftLink(
f"instrument/{motor_name}/value", relative=False
)
datasets.append(motor_value)
return datasets
def _resolve_virtual_axis(
item: _H5DatasetExpression, top_group: h5py.Group, axes_units: Dict[str, str]
) -> None:
def get_data(name: str) -> Tuple[str, Quantity]:
return _get_moving_values(
top_group, name, axes_units.get(name, "") if axes_units else ""
)
expression, variables, name_map = expression_variables(
item.expression, get_data, start_var=item.start_var, end_var=item.end_var
)
_ensure_same_shape0(variables, name_map)
quantity = eval_expression(expression, variables)
if item.units:
quantity = quantity.to(item.units)
item.result = quantity
def _get_moving_values(
scan: h5py.Group, name: str, default_units: str
) -> Tuple[str, Quantity]:
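    # Lookup order: measurement dataset, then instrument group (NXpositioner
    # "value" or NXdetector "data"), then positioners.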
measurement = scan["measurement"]
if name in measurement:
full_name = f"{measurement.name}/{name}"
values = _get_value_with_units(measurement[name], default_units)
return full_name, values
instrument = scan["instrument"]
if name in instrument:
group = instrument[name]
if "value" in group:
# NXpositioner
full_name = f"{group.name}/value"
values = _get_value_with_units(group["value"], default_units)
return full_name, values
else:
# NXdetector
full_name = f"{group.name}/data"
values = _get_value_with_units(group["data"], default_units)
return full_name, values
positioners = instrument["positioners"]
if name in positioners:
full_name = f"{positioners.name}/{name}"
values = _get_value_with_units(positioners[name], default_units)
return full_name, values
raise RuntimeError(f"'{name}' is neither a detector nor a positioner")
def _get_value_with_units(dset: h5py.Dataset, default_units: str) -> Quantity:
units = dset.attrs.get("units", default_units)
ureg = unit_registry()
return dset[()] * ureg.parse_units(units)