Source code for ewoksfluo.tasks.input.pick_scan_groups
from itertools import zip_longest
from typing import List
from typing import Optional
from typing import Tuple
try:
from typing import Self
except ImportError:
from typing_extensions import Self
from ewokscore import Task
from ewokscore.model import BaseInputModel
from ewokscore.model import BaseOutputModel
from pydantic import Field
from pydantic import model_validator
from .pick_utils import pick_scans
[docs]
class Inputs(BaseInputModel):
    # The outer level of `filenames`, `scan_ranges` and `exclude_scans` is
    # indexed by group. Within a group there is one entry per file; the
    # `check_and_expand_lengths` validator enforces this pairing.

    filenames: List[List[str]] = Field(
        description="Groups of Bliss dataset HDF5 file names.",
        examples=[
            [
                ["/data/dataset1.h5", "/data/dataset2.h5"],
                ["/data/dataset3.h5", "/data/dataset4.h5"],
            ]
        ],
    )

    # One tuple per file; presumably (first, last) scan numbers -- confirm
    # whether the upper bound is inclusive in `pick_scans`.
    scan_ranges: List[List[Tuple[int, int]]] = Field(
        description="Ranges of scan numbers for each group.",
        examples=[
            [
                [(1, 4), (100, 101)],
                [(1, 4), (100, 101)],
            ]
        ],
    )

    # May be empty as a whole or per group; the validator expands empty
    # entries into "exclude nothing" lists.
    exclude_scans: List[List[List[int]]] = Field(
        description="Scan numbers to exclude for each range.",
        default_factory=list,
        examples=[
            [
                [[1, 3], [1]],
                [[1, 3], [1]],
            ]
        ],
    )

    group_by_index: bool = Field(
        default=False,
        description="If False, preserve the input grouping. If True, regroup scans by scan index across input groups.",
    )

    retry_timeout: Optional[float] = Field(
        default=20,
        description="Timeout in seconds when waiting for the scan to be fully written. "
        "`None` means wait forever. Negative means do not wait.",
    )

    retry_period: float = Field(
        default=0.5,
        description="Retry period in seconds when waiting for the scan to be fully written.",
    )

    @model_validator(mode="after")
    def check_and_expand_lengths(self) -> Self:
        """Check that groups and per-file entries line up, and fill in exclusions.

        ``filenames`` and ``scan_ranges`` must contain the same number of
        groups. Within each group they must contain the same number of
        entries. ``exclude_scans`` may be empty, either entirely or for a
        single group. Empty entries are replaced in place by one empty
        exclusion list per file.

        :returns: this model instance (possibly with ``exclude_scans`` expanded).
        :raises ValueError: when any group or per-file length does not match.
        """
        group_count = len(self.filenames)
        if len(self.scan_ranges) != group_count:
            raise ValueError("`filenames` and `scan_ranges` must have the same length")

        # An empty top-level list means "exclude nothing in any group".
        if not self.exclude_scans:
            self.exclude_scans = [[] for _ in range(group_count)]
        if len(self.exclude_scans) != group_count:
            raise ValueError(
                "`exclude_scans` must have the same length as `filenames` or be empty"
            )

        per_group = zip(self.filenames, self.scan_ranges, self.exclude_scans)
        for group_index, (group_files, group_ranges, group_excludes) in enumerate(
            per_group
        ):
            file_count = len(group_files)
            if len(group_ranges) != file_count:
                raise ValueError(
                    f"Group {group_index}: `filenames` and `scan_ranges` must have the same length"
                )
            if not group_excludes:
                # An empty group entry means "exclude nothing for any file".
                self.exclude_scans[group_index] = [[] for _ in range(file_count)]
            elif len(group_excludes) != file_count:
                raise ValueError(
                    f"Group {group_index}: `exclude_scans` must have the same length as `filenames` or be empty"
                )
        return self
[docs]
class Outputs(BaseOutputModel):
    # One inner list of Bliss scan URIs per output group. Without
    # `group_by_index` this follows the input grouping; with it, the groups
    # are the transpose (k-th URI of every input group).
    bliss_scan_uris: List[List[str]] = Field(
        description="Several lists of Bliss scan URI's. "
        "Use `group_by_index=True` to get the transpose.",
        # `examples` is a list of example *values*. Each value must itself be
        # a List[List[str]], hence the outer wrapping list (same convention as
        # the examples in `Inputs`).
        examples=[
            [
                [
                    "/data/dataset1.h5::/2.1",
                    "/data/dataset2.h5::/2.1",
                ],
                [
                    "/data/dataset1.h5::/3.1",
                    "/data/dataset2.h5::/3.1",
                ],
            ]
        ],
    )
[docs]
class PickScanGroups(Task, input_model=Inputs, output_model=Outputs):
    """Select groups of Bliss scans from multiple groups of files.

    Groups are not required to match in terms of number of files
    or number of scans.
    """

    def run(self):
        """Pick the scan URIs of every group, optionally regrouped by index.

        Each input group is resolved with ``pick_scans``. When
        ``group_by_index`` is set, the per-group URI lists are transposed.
        Groups of unequal length are allowed; missing positions are dropped.
        """
        # NOTE(review): `retry_timeout` and `retry_period` are declared in the
        # input model but are not forwarded to `pick_scans` here -- confirm
        # whether `pick_scans` should receive them.
        grouped_uris = [
            pick_scans(group_files, group_ranges, group_excludes)
            for group_files, group_ranges, group_excludes in zip(
                self.inputs.filenames,
                self.inputs.scan_ranges,
                self.inputs.exclude_scans,
            )
        ]

        if not self.inputs.group_by_index:
            self.outputs.bliss_scan_uris = grouped_uris
            return

        # Transpose: the k-th output list holds the k-th URI of every group.
        # zip_longest pads shorter groups with None, which is filtered out.
        self.outputs.bliss_scan_uris = [
            [uri for uri in column if uri is not None]
            for column in zip_longest(*grouped_uris)
        ]