Skip to content

DatasetMetaData

Source code in cellulus/datasets/meta_data.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class DatasetMetaData:
    """Axis layout of a dataset, derived from its shape and axis names.

    Attributes set by the constructor:
        num_dims: number of entries in ``axis_names``.
        num_spatial_dims: number of axes named "t", "z", "y" or "x"
            (note that the time axis is counted here as well).
        num_samples: extent of the "s" axis.
        num_channels: extent of the "c" axis.
        sample_dim: index of the "s" axis.
        channel_dim: index of the "c" axis.
        time_dim: index of the "t" axis, or ``None`` if there is none.
        spatial_array: extents of the "z", "y" and "x" axes, in the order
            they appear in ``axis_names`` (the "t" axis is not included).
    """

    def __init__(self, shape, axis_names):
        """Derive the meta-data from a dataset shape and its axis names.

        Args:
            shape: extent of the dataset along each axis.
            axis_names: one name per axis; "s", "c", "t", "z", "y" and "x"
                are interpreted, any other name is ignored.

        Raises:
            RuntimeError: if ``shape`` and ``axis_names`` differ in length,
                or if there is no "s" or no "c" axis.
        """
        self.num_dims = len(axis_names)
        self.num_spatial_dims: int = 0
        self.num_samples: int = 0
        self.num_channels: int = 0
        self.sample_dim = None
        self.channel_dim = None
        self.time_dim = None
        self.spatial_array: Tuple[int, ...] = ()

        # Validate the dimension count before indexing into ``shape``:
        # otherwise a too-short shape surfaces as a bare IndexError from the
        # loop below instead of the descriptive RuntimeError.
        if self.num_dims != len(shape):
            self.__invalid_dataset(
                f"dataset has {len(shape)} dimensions, but attribute "
                f"axis_names has {self.num_dims} entries"
            )

        for dim, axis_name in enumerate(axis_names):
            if axis_name == "s":
                self.sample_dim = dim
                self.num_samples = shape[dim]
            elif axis_name == "c":
                self.channel_dim = dim
                self.num_channels = shape[dim]
            elif axis_name == "t":
                # Time counts towards num_spatial_dims but its extent is not
                # recorded in spatial_array.
                self.num_spatial_dims += 1
                self.time_dim = dim
            elif axis_name in ("z", "y", "x"):
                self.num_spatial_dims += 1
                self.spatial_array += (shape[dim],)

        if self.sample_dim is None:
            self.__invalid_dataset("dataset does not have a sample dimension")

        if self.channel_dim is None:
            self.__invalid_dataset("dataset does not have a channel dimension")

    @staticmethod
    def from_dataset_config(dataset_config: "DatasetConfig") -> "DatasetMetaData":
        """Build the meta-data for the dataset a configuration points to.

        Opens the zarr container at ``dataset_config.container_path``
        read-only, looks up ``dataset_config.dataset_name`` in it and reads
        that dataset's shape and "axis_names" attribute.

        Raises:
            RuntimeError: if the dataset is not in the container, if it has
                no "axis_names" attribute, or if its shape and axis names
                do not form valid meta-data.
        """
        container = zarr.open(dataset_config.container_path, "r")
        try:
            data = container[dataset_config.dataset_name]
        except KeyError:
            # __invalid_dataset always raises, so ``data`` is bound below.
            DatasetMetaData.__invalid_dataset(
                f"Zarr container {dataset_config.container_path} does not contain "
                f'"{dataset_config.dataset_name}" dataset'
            )

        try:
            axis_names = data.attrs["axis_names"]
        except KeyError:
            DatasetMetaData.__invalid_dataset(
                f'"{dataset_config.dataset_name}" dataset in '
                f'{dataset_config.container_path} does not contain "axis_names" '
                "attribute"
            )

        try:
            return DatasetMetaData(data.shape, axis_names)
        except RuntimeError as e:
            # Re-raise with the dataset/container named, keeping the
            # original validation error as the cause.
            raise RuntimeError(
                f'"{dataset_config.dataset_name}" dataset in '
                f"{dataset_config.container_path} has invalid meta-data"
            ) from e

    @staticmethod
    def __invalid_dataset(message):
        """Raise a RuntimeError with ``message`` plus a format reminder.

        Never returns: the error text is ``message`` followed by a blank
        line and a description of the expected dataset layout.
        """
        raise RuntimeError(
            message
            + "\n\n"
            + (
                "The raw dataset should have shape "
                "(s, c, [t,] [z,] y, x), where s = # of samples, c = # of channels, "
                "t = # of frames, and z/y/x are spatial extents. The dataset should "
                'have an "axis_names" attribute that contains the names of the used '
                'axes, e.g., ["s", "c", "y", "x"] for a 2D dataset.'
            )
        )