# Source code for mojito.utils

"""
Utilities
=========

This module contains utility functions for various tasks.

.. autofunction:: print_hdf5_structure

.. autofunction:: assert_datasets_almost_equal

"""

import logging
from pathlib import Path
from typing import Any, Callable, Literal, Sequence, TypeVar, overload

import numpy as np
from h5py import Dataset, File, Group

logger = logging.getLogger(__name__)


def _validate_require(require: Literal["all", "one"] | None) -> None:
    """Validate that the require parameter is one of the allowed values."""
    if require not in ("all", "one", None):
        raise ValueError("require must be 'all', 'one', or None")






# Generic type variable for the value produced by the optional ``cast``
# callable in the ``_get_attrs`` overloads below.
T = TypeVar("T")


def _get_datasets(
    name: str,
    groups: Sequence[Group],
    *,
    require: Literal["all", "one"] | None = "all",
) -> list[Dataset]:
    """Collect the dataset called *name* from each group in *groups*.

    Parameters
    ----------
    name
        Name of the dataset to look up in each group.
    groups
        HDF5 groups to search for the dataset.
    require
        Presence policy. With ``"all"``, every group must contain the
        dataset; with ``"one"``, at least one group must contain it; with
        ``None``, groups missing the dataset are silently skipped (an
        informational message is logged).

    Returns
    -------
        The datasets named *name*, one per group that contains it.

    Raises
    ------
    KeyError
        If a group is missing the dataset and ``require == "all"``, or if no
        group has it and ``require == "one"``.
    TypeError
        If the member called *name* is not an :class:`h5py.Dataset`.
    ValueError
        If *require* is not ``"all"``, ``"one"``, or ``None``.
    """
    _validate_require(require)

    found: list[Dataset] = []
    for grp in groups:
        if name in grp:
            member = grp[name]
            # Guard against a group (or other node) hiding under the name.
            if not isinstance(member, Dataset):
                raise TypeError(
                    f"Expected dataset '{grp}/{name}', but found {type(member)}"
                )
            found.append(member)
        elif require == "all":
            raise KeyError(f"Group {grp} does not contain dataset '{name}'")
        else:
            logger.info("Group %s does not contain dataset '%s'", grp, name)
    if not found and require == "one":
        raise KeyError(f"No groups contain dataset '{name}'")
    return found


def _get_groups(
    name: str,
    groups: Sequence[Group],
    *,
    require: Literal["all", "one"] | None = "all",
) -> list[Group]:
    """Collect the subgroup called *name* from each group in *groups*.

    Parameters
    ----------
    name
        Name of the subgroup to look up in each group.
    groups
        HDF5 groups to search for the subgroup.
    require
        Presence policy. With ``"all"``, every group must contain the
        subgroup; with ``"one"``, at least one group must contain it; with
        ``None``, groups missing the subgroup are silently skipped (an
        informational message is logged).

    Returns
    -------
        The subgroups named *name*, one per group that contains it.

    Raises
    ------
    KeyError
        If a group is missing the subgroup and ``require == "all"``, or if
        no group has it and ``require == "one"``.
    TypeError
        If the member called *name* is not an :class:`h5py.Group`.
    ValueError
        If *require* is not ``"all"``, ``"one"``, or ``None``.
    """
    _validate_require(require)

    collected: list[Group] = []
    for parent in groups:
        if name not in parent:
            if require == "all":
                raise KeyError(f"Group {parent} does not contain group '{name}'")
            logger.info("Group %s does not contain group '%s'", parent, name)
        else:
            child = parent[name]
            # Guard against a dataset (or other node) hiding under the name.
            if isinstance(child, Group):
                collected.append(child)
            else:
                raise TypeError(
                    f"Expected group '{parent}/{name}', but found {type(child)}"
                )
    if require == "one" and not collected:
        raise KeyError(f"No groups contain group '{name}'")
    return collected


@overload
def _get_attrs(
    name: str,
    groups: Sequence[Group],
    *,
    require: Literal["all", "one"] | None = "all",
    cast: None = None,
) -> list[Any]: ...


@overload
def _get_attrs(
    name: str,
    groups: Sequence[Group],
    *,
    require: Literal["all", "one"] | None = "all",
    cast: Callable[[Any], T],
) -> list[T]: ...


def _get_attrs(
    name: str,
    groups: Sequence[Group],
    *,
    require: Literal["all", "one"] | None = "all",
    cast: Callable[[Any], Any] | None = None,
) -> list[Any]:
    """Collect the attribute called *name* from each group in *groups*.

    Parameters
    ----------
    name
        Name of the attribute to look up on each group.
    groups
        HDF5 groups whose attributes are searched.
    require
        Presence policy. With ``"all"``, every group must carry the
        attribute; with ``"one"``, at least one group must carry it; with
        ``None``, groups missing the attribute are silently skipped (an
        informational message is logged).
    cast
        Optional converter applied to each attribute value before it is
        returned.

    Returns
    -------
        The attribute values named *name*, one per group that carries it.

    Raises
    ------
    KeyError
        If a group is missing the attribute and ``require == "all"``, or if
        no group has it and ``require == "one"``.
    ValueError
        If *require* is not ``"all"``, ``"one"``, or ``None``.
    TypeError
        If *cast* is given and raises while converting a value.
    """
    _validate_require(require)

    values: list[Any] = []
    for grp in groups:
        if name not in grp.attrs:
            if require == "all":
                raise KeyError(f"Group {grp} does not contain attribute '{name}'")
            logger.info("Group %s does not contain attribute '%s'", grp, name)
            continue
        value = grp.attrs[name]
        if cast is not None:
            # Wrap any conversion failure so the offending group/attribute
            # is identifiable from the error message.
            try:
                value = cast(value)
            except Exception as e:
                raise TypeError(
                    f"Failed to cast attribute '{grp}/{name}' with value "
                    f"{value} using {cast}: {e}"
                ) from e
        values.append(value)
    if require == "one" and not values:
        raise KeyError(f"No groups contain attribute '{name}'")
    return values


def assert_datasets_almost_equal(
    datasets: Sequence[Dataset],
    *,
    chunk: int = 100_000,
) -> None:
    """Assert that the given datasets are almost equal, reading them in chunks.

    This is a helper function to check that datasets are identical across
    groups without loading them entirely into memory. Note that chunking is
    only done along the first dimension of the datasets, so this function is
    most effective for datasets where the first dimension is the largest.

    Parameters
    ----------
    datasets
        Sequence of datasets to compare. All datasets must have the same
        shape.
    chunk
        Chunk size to use on the first axis when comparing the datasets. This
        limits memory usage when comparing large datasets.

    Raises
    ------
    ValueError
        If the chunk size is not positive.
    AssertionError
        If the datasets have different shapes.
    AssertionError
        If the datasets are not almost equal within the given chunk size.
    """
    # Check chunk size is positive
    if chunk <= 0:
        raise ValueError("Chunk size must be positive")

    # Nothing to compare with fewer than two datasets
    if len(datasets) < 2:
        return

    # Take first dataset as reference and check that all datasets share its
    # shape before reading any data
    reference = datasets[0]
    reference_shape = reference.shape
    for dataset in datasets[1:]:
        if dataset.shape != reference_shape:
            raise AssertionError("Datasets have different shapes")

    # 0-D (scalar) datasets cannot be sliced along an axis; compare directly
    if len(reference_shape) == 0:
        for dataset in datasets[1:]:
            np.testing.assert_array_almost_equal(
                dataset[()],
                reference[()],
                err_msg="Inconsistent datasets across files",
            )
        return

    # Chunking only bounds the first axis, so multi-dimensional datasets may
    # still require a lot of memory per chunk — warn the caller
    if len(reference_shape) > 1:
        logger.warning(
            "Comparing datasets with more than 1 dimension using chunking "
            "along the first axis. This may be memory-intensive if the "
            "datasets are large along other axes."
        )

    # Check that all datasets are almost equal, chunk by chunk along the
    # first axis
    for i in range(0, reference_shape[0], chunk):
        reference_chunk = reference[i : i + chunk]
        for dataset in datasets[1:]:
            np.testing.assert_array_almost_equal(
                dataset[i : i + chunk],
                reference_chunk,
                err_msg="Inconsistent datasets across files",
            )