Skip to content

TrainConfig

Train configuration.

Parameters

crop_size (default = [252, 252]):

    The size of the crops - specified as a list of number of pixels -
    extracted from the raw images, used during training.

batch_size (default = 8):

    The number of samples to use per batch.

max_iterations (default = 100000):

    The maximum number of iterations to train for.

initial_learning_rate (default = 4e-5):

    Initial learning rate of the optimizer.

temperature (default = 10):

    Factor used to scale the gaussian function and control the rate of damping.

regularizer_weight (default = 1e-5):

    The weight of the L2 regularizer on the object-centric embeddings.

reduce_mean (default = True):

    If True, the loss contribution is averaged across all pairs of patches.

density (default = 0.1)

    Determines the fraction of patches to sample per crop, during training.

kappa (default = 10.0):

    Neighborhood radius to extract patches from.

save_model_every (default = 1000):

    The model weights are saved every few iterations.

save_best_model_every (default = 100):

    The best loss is evaluated every few iterations.

save_snapshot_every (default = 1000):

    The zarr snapshot is saved every few iterations.

num_workers (default = 8):

    The number of sub-processes to use for data-loading.

elastic_deform (default = True):

    If set to True, the data is elastically deformed
    in order to increase training samples.

control_point_spacing (default = 64):

    The distance in pixels between control points used for elastic
    deformation of the raw data during training.
    Only used if `elastic_deform` is set to True.

control_point_jitter (default = 2.0):

    How much to jitter the control points for elastic deformation
    of the raw data during training, given as the standard deviation of
    a normal distribution with zero mean.
    Only used if `elastic_deform` is set to True.

train_data_config:

    Configuration object for the training data.

validate_data_config (default = None):

    Configuration object for the validation data.

device (default = 'cuda:0'):

    The device to train on.
    Set to 'cpu' to train without GPU.
Source code in cellulus/configs/train_config.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
@attrs.define
class TrainConfig:
    """Train configuration.

    Parameters
    ----------

        crop_size (default = [252, 252]):

            The size of the crops - specified as a list of number of pixels -
            extracted from the raw images, used during training.

        batch_size (default = 8):

            The number of samples to use per batch.

        max_iterations (default = 100000):

            The maximum number of iterations to train for.

        initial_learning_rate (default = 4e-5):

            Initial learning rate of the optimizer.

        temperature (default = 10):

            Factor used to scale the gaussian function and control the rate of damping.

        regularizer_weight (default = 1e-5):

            The weight of the L2 regularizer on the object-centric embeddings.

        reduce_mean (default = True):

            If True, the loss contribution is averaged across all pairs of patches.

        density (default = 0.1)

            Determines the fraction of patches to sample per crop, during training.

        kappa (default = 10.0):

            Neighborhood radius to extract patches from.

        save_model_every (default = 1000):

            The model weights are saved every few iterations.

        save_best_model_every (default = 100):

            The best loss is evaluated every few iterations.

        save_snapshot_every (default = 1000):

            The zarr snapshot is saved every few iterations.

        num_workers (default = 8):

            The number of sub-processes to use for data-loading.

        elastic_deform (default = True):

            If set to True, the data is elastically deformed
            in order to increase training samples.

        control_point_spacing (default = 64):

            The distance in pixels between control points used for elastic
            deformation of the raw data during training.
            Only used if `elastic_deform` is set to True.

        control_point_jitter (default = 2.0):

            How much to jitter the control points for elastic deformation
            of the raw data during training, given as the standard deviation of
            a normal distribution with zero mean.
            Only used if `elastic_deform` is set to True.

        train_data_config:

            Configuration object for the training data.

        validate_data_config (default = None):

            Configuration object for the validation data.

        device (default = 'cuda:0'):

            The device to train on.
            Set to 'cpu' to train without GPU.


    """

    train_data_config: DatasetConfig = attrs.field(
        default=None, converter=to_config(DatasetConfig)
    )
    validate_data_config: DatasetConfig = attrs.field(
        default=None, converter=to_config(DatasetConfig)
    )
    crop_size: List = attrs.field(default=[252, 252], validator=instance_of(List))
    batch_size: int = attrs.field(default=8, validator=instance_of(int))
    max_iterations: int = attrs.field(default=100_000, validator=instance_of(int))
    initial_learning_rate: float = attrs.field(
        default=4e-5, validator=instance_of(float)
    )
    density: float = attrs.field(default=0.1, validator=instance_of(float))
    kappa: float = attrs.field(default=10.0, validator=instance_of(float))
    temperature: float = attrs.field(default=10.0, validator=instance_of(float))
    regularizer_weight: float = attrs.field(default=1e-5, validator=instance_of(float))
    save_model_every: int = attrs.field(default=1_000, validator=instance_of(int))
    save_best_model_every: int = attrs.field(default=100, validator=instance_of(int))
    save_snapshot_every: int = attrs.field(default=1_000, validator=instance_of(int))
    num_workers: int = attrs.field(default=8, validator=instance_of(int))
    elastic_deform: bool = attrs.field(default=True, validator=instance_of(bool))
    control_point_spacing: int = attrs.field(default=64, validator=instance_of(int))
    control_point_jitter: float = attrs.field(default=2.0, validator=instance_of(float))
    device: str = attrs.field(default="cuda:0", validator=instance_of(str))