"""
Utilities
=========
This module contains utility functions for various tasks.
.. autofunction:: print_hdf5_structure
.. autofunction:: assert_datasets_almost_equal
"""
import logging
from pathlib import Path
from typing import Any, Callable, Literal, Sequence, TypeVar, overload
import numpy as np
from h5py import Dataset, File, Group
logger = logging.getLogger(__name__)  # module-level logger, named after this module
def _validate_require(require: Literal["all", "one"] | None) -> None:
"""Validate that the require parameter is one of the allowed values."""
if require not in ("all", "one", None):
raise ValueError("require must be 'all', 'one', or None")
def print_hdf5_structure(path: str | Path, *, print_attrs: bool = False) -> None:
    """Print the structure of an HDF5 file with visual hierarchy.

    Recursively prints the hierarchy of groups and datasets in an HDF5 file,
    displaying their names, types, and shapes. Optionally prints attributes for
    groups and datasets.

    Parameters
    ----------
    path
        Path to the HDF5 file.
    print_attrs
        Whether to print attributes for groups and datasets.

    Example
    -------
    >>> print_hdf5_structure("data.h5", print_attrs=True)
    HDF5 Structure of data.h5
    └── [GROUP] /
        ├── @version: 1.0
        ├── [DATASET] /dataset1
        │       dtype: float64, shape: (100,)
        └── [GROUP] /group1
            └── [DATASET] /group1/dataset2
                    dtype: int32, shape: (50, 50)
                @units: meters
    """

    def _print_attrs(obj: Group | Dataset, prefix: str) -> None:
        """Print each attribute of *obj* as a tree leaf under *prefix*."""
        attr_keys = list(obj.attrs.keys())
        for i, key in enumerate(attr_keys):
            branch = "└── " if i == len(attr_keys) - 1 else "├── "
            print(f"{prefix}{branch}@{key}: {obj.attrs[key]}")

    def print_structure(
        obj: Group | Dataset,
        prefix: str = "",
        is_last: bool = True,
    ) -> None:
        """Recursively print HDF5 structure with tree-like visualization."""
        current_prefix = "└── " if is_last else "├── "
        next_prefix = "    " if is_last else "│   "
        if isinstance(obj, Group):
            print(f"{prefix}{current_prefix}[GROUP] {obj.name}")
            if print_attrs and obj.attrs:
                _print_attrs(obj, prefix + next_prefix)
            keys = list(obj.keys())
            for i, key in enumerate(keys):
                child = obj[key]
                # Silently skip anything that is neither a Group nor a Dataset
                # (e.g. committed datatypes), matching the original behavior.
                if isinstance(child, (Group, Dataset)):
                    print_structure(child, prefix + next_prefix, i == len(keys) - 1)
        elif isinstance(obj, Dataset):
            print(f"{prefix}{current_prefix}[DATASET] {obj.name}")
            # dtype/shape details are indented one extra level under the name.
            print(f"{prefix}{next_prefix}    dtype: {obj.dtype}, shape: {obj.shape}")
            if print_attrs and obj.attrs:
                _print_attrs(obj, prefix + next_prefix)

    # Open read-only; the File object itself acts as the root Group ("/").
    with File(str(path), "r") as f:
        print(f"HDF5 Structure of {path}")
        print_structure(f)
T = TypeVar("T")  # generic result type used by the `cast` overloads of _get_attrs
def _get_datasets(
    name: str,
    groups: Sequence[Group],
    *,
    require: Literal["all", "one"] | None = "all",
) -> list[Dataset]:
    """Collect the dataset called ``name`` from each of the given groups.

    Parameters
    ----------
    name
        Name of the dataset to look up in each group.
    groups
        HDF5 groups to search.
    require
        Presence requirement. "all" raises if any group lacks the dataset;
        "one" raises if no group has it; None silently skips groups that
        lack it.

    Returns
    -------
    Datasets named ``name``, one per group that contains it.

    Raises
    ------
    KeyError
        If a group lacks the dataset and require is "all", or no group has
        it and require is "one".
    TypeError
        If the member called ``name`` is not a :class:`h5py.Dataset`.
    ValueError
        If require is not "all", "one", or None.
    """
    _validate_require(require)
    found: list[Dataset] = []
    for group in groups:
        if name in group:
            candidate = group[name]
            if not isinstance(candidate, Dataset):
                raise TypeError(
                    f"Expected dataset '{group}/{name}', but found {type(candidate)}"
                )
            found.append(candidate)
        elif require == "all":
            raise KeyError(f"Group {group} does not contain dataset '{name}'")
        else:
            logger.info("Group %s does not contain dataset '%s'", group, name)
    if not found and require == "one":
        raise KeyError(f"No groups contain dataset '{name}'")
    return found
def _get_groups(
    name: str,
    groups: Sequence[Group],
    *,
    require: Literal["all", "one"] | None = "all",
) -> list[Group]:
    """Collect the subgroup called ``name`` from each of the given groups.

    Parameters
    ----------
    name
        Name of the subgroup to look up in each group.
    groups
        HDF5 groups to search.
    require
        Presence requirement. "all" raises if any group lacks the subgroup;
        "one" raises if no group has it; None silently skips groups that
        lack it.

    Returns
    -------
    Subgroups named ``name``, one per group that contains it.

    Raises
    ------
    KeyError
        If a group lacks the subgroup and require is "all", or no group has
        it and require is "one".
    TypeError
        If the member called ``name`` is not a :class:`h5py.Group`.
    ValueError
        If require is not "all", "one", or None.
    """
    _validate_require(require)
    found: list[Group] = []
    for group in groups:
        if name in group:
            candidate = group[name]
            if not isinstance(candidate, Group):
                raise TypeError(
                    f"Expected group '{group}/{name}', but found {type(candidate)}"
                )
            found.append(candidate)
        elif require == "all":
            raise KeyError(f"Group {group} does not contain group '{name}'")
        else:
            logger.info("Group %s does not contain group '%s'", group, name)
    if not found and require == "one":
        raise KeyError(f"No groups contain group '{name}'")
    return found
@overload
def _get_attrs(
    name: str,
    groups: Sequence[Group],
    *,
    require: Literal["all", "one"] | None = "all",
    cast: None = None,
) -> list[Any]: ...
@overload
def _get_attrs(
    name: str,
    groups: Sequence[Group],
    *,
    require: Literal["all", "one"] | None = "all",
    cast: Callable[[Any], T],
) -> list[T]: ...
def _get_attrs(
    name: str,
    groups: Sequence[Group],
    *,
    require: Literal["all", "one"] | None = "all",
    cast: Callable[[Any], Any] | None = None,
) -> list[Any]:
    """Collect the attribute called ``name`` from each of the given groups.

    Parameters
    ----------
    name
        Name of the attribute to look up on each group.
    groups
        HDF5 groups whose attributes are searched.
    require
        Presence requirement. "all" raises if any group lacks the attribute;
        "one" raises if no group has it; None silently skips groups that
        lack it.
    cast
        Optional converter applied to each attribute value before it is
        returned.

    Returns
    -------
    Attribute values named ``name``, one per group that carries it.

    Raises
    ------
    KeyError
        If a group lacks the attribute and require is "all", or no group has
        it and require is "one".
    ValueError
        If require is not "all", "one", or None.
    TypeError
        If ``cast`` is given and fails on an attribute value.
    """
    _validate_require(require)
    values: list[Any] = []
    for group in groups:
        if name not in group.attrs:
            if require == "all":
                raise KeyError(f"Group {group} does not contain attribute '{name}'")
            logger.info("Group %s does not contain attribute '%s'", group, name)
            continue
        value = group.attrs[name]
        if cast is not None:
            try:
                value = cast(value)
            except Exception as e:
                # `value` is still the raw attribute here, since cast failed.
                raise TypeError(
                    f"Failed to cast attribute '{group}/{name}' with value "
                    f"{value} using {cast}: {e}"
                ) from e
        values.append(value)
    if not values and require == "one":
        raise KeyError(f"No groups contain attribute '{name}'")
    return values
def assert_datasets_almost_equal(
    datasets: Sequence[Dataset],
    *,
    chunk: int = 100_000,
) -> None:
    """Assert that the given datasets are almost equal, reading them in chunks.

    This is a helper function to check that datasets are identical across groups
    without loading them entirely into memory.

    Note that chunking is only done along the first dimension of the datasets,
    so this function is most effective for datasets where the first dimension is
    the largest.

    Parameters
    ----------
    datasets
        Sequence of datasets to compare. All datasets must have the same shape.
    chunk
        Chunk size to use on the first axis when comparing the datasets. This
        limits memory usage when comparing large datasets.

    Raises
    ------
    ValueError
        If the chunk size is not positive.
    AssertionError
        If the datasets have different shapes, or if any pair of corresponding
        chunks is not almost equal.
    """
    # Check chunk size is positive
    if chunk <= 0:
        raise ValueError("Chunk size must be positive")
    # Fewer than two datasets are trivially consistent
    if len(datasets) < 2:
        return
    # Take first dataset as reference
    reference = datasets[0]
    # Check that all datasets have the same shape
    reference_shape = reference.shape
    for dataset in datasets[1:]:
        if dataset.shape != reference_shape:
            raise AssertionError("Datasets have different shapes")
    # For 0-D (scalar) datasets, compare directly without chunking
    if len(reference_shape) == 0:
        for dataset in datasets[1:]:
            np.testing.assert_array_almost_equal(
                dataset[()],
                reference[()],
                err_msg="Inconsistent datasets across files",
            )
        return
    # For higher-dimensional datasets, compare in chunks along the first axis
    # and raise a warning if the number of dimensions is greater than 1
    if len(reference_shape) > 1:
        logger.warning(
            "Comparing datasets with more than 1 dimension using chunking "
            "along the first axis. This may be memory-intensive if the "
            "datasets are large along other axes."
        )
    # Check that all datasets are almost equal in chunks
    for i in range(0, reference_shape[0], chunk):
        reference_chunk = reference[i : i + chunk]
        for dataset in datasets[1:]:
            np.testing.assert_array_almost_equal(
                dataset[i : i + chunk],
                reference_chunk,
                err_msg="Inconsistent datasets across files",
            )