Skip to content

OCELoss

Bases: Module

Source code in cellulus/criterions/oce_loss.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class OCELoss(nn.Module):  # type: ignore
    """Object-centric embedding (OCE) loss.

    For each anchor/reference pair the loss term is
    ``1 - exp(-d**2 / temperature)``, where ``d`` is the Euclidean distance
    between the two embeddings along the last axis; the terms are summed over
    all pairs. An L2-norm penalty on the anchor embeddings, scaled by
    ``regularization_weight``, is added on top.
    """

    def __init__(
        self,
        temperature: float,
        regularization_weight: float,
        density: float,
        num_spatial_dims: int,
        device: torch.device,
    ):
        """Initialize the OCE loss.

        Parameters
        ----------
        temperature : float
            Scale of the Gaussian kernel used by ``non_linearity``: the squared
            distance is divided by this value before exponentiation, so it
            controls how quickly the loss saturates with distance (the rate
            of damping). NOTE(review): a non-positive value does not give a
            bounded loss; it is not validated here.
        regularization_weight : float
            The weight of the L2 regularizer on the object-centric embeddings.
        density : float
            Determines the fraction of patches to sample per crop during
            training. Stored only; not read by this class.
        num_spatial_dims : int
            Should be equal to 2 for 2D and 3 for 3D. Stored only; not read
            by this class.
        device : torch.device
            The device to train on. Set to 'cpu' to train without GPU.
            Stored only; not read by this class.
        """
        super().__init__()
        self.temperature = temperature
        self.regularization_weight = regularization_weight
        self.density = density
        self.num_spatial_dims = num_spatial_dims
        self.device = device

    @staticmethod
    def distance_function(embedding_0, embedding_1):
        """Return the Euclidean (L2) distance between embeddings along the last axis.

        The inputs are subtracted with normal broadcasting rules. The last axis
        is reduced away.
        """
        difference = embedding_0 - embedding_1
        return difference.norm(2, dim=-1)

    def non_linearity(self, distance):
        """Map a distance to ``1 - exp(-distance**2 / temperature)``.

        For a positive temperature the result is 0 at distance 0 and
        approaches 1 as the distance grows.

        ``-expm1(-x)`` is used instead of ``1 - exp(-x)``. The two are
        mathematically identical, but the latter cancels catastrophically for
        small ``x``: in float32, ``exp(-1e-10)`` rounds to exactly 1, so the
        loss of nearby pairs would be reported as 0.
        """
        return -torch.expm1(-distance.pow(2) / self.temperature)

    def forward(self, anchor_embedding, reference_embedding):
        """Compute the OCE loss for anchor/reference embedding pairs.

        Parameters
        ----------
        anchor_embedding : torch.Tensor
            Embeddings to be pulled toward the references. The embedding
            dimension is the last axis. Gradients flow through this tensor.
        reference_embedding : torch.Tensor
            Target embeddings, broadcast-compatible with ``anchor_embedding``.
            This tensor is detached, so no gradient flows into it.

        Returns
        -------
        tuple of torch.Tensor
            ``(loss, oce_loss, regularization_loss)``. These are scalar tensors
            with ``loss = oce_loss + regularization_loss``.
        """
        # Detach the reference so that only the anchor embedding is optimized.
        distance = self.distance_function(
            anchor_embedding, reference_embedding.detach()
        )
        oce_loss = self.non_linearity(distance).sum()
        # Penalizes the (non-squared) L2 norm of every anchor embedding.
        regularization_loss = (
            self.regularization_weight * anchor_embedding.norm(2, dim=-1).sum()
        )
        loss = oce_loss + regularization_loss
        return loss, oce_loss, regularization_loss

__init__(temperature, regularization_weight, density, num_spatial_dims, device)

Initialize the object-centric embedding (OCE) loss.

Parameters

temperature:
    Scale of the Gaussian kernel: the squared distance is divided by this
    value before exponentiation, which controls the rate of damping.

regularization_weight:
    The weight of the L2 regularizer on the object-centric embeddings.

density:
    Determines the fraction of patches to sample per crop,
    during training.

num_spatial_dims:
    Should be equal to 2 for 2D and 3 for 3D.

device:
    The device to train on.
    Set to 'cpu' to train without GPU.
Source code in cellulus/criterions/oce_loss.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def __init__(
    self,
    temperature: float,
    regularization_weight: float,
    density: float,
    num_spatial_dims: int,
    device: torch.device,
):
    """Initialize the object-centric embedding (OCE) loss.

    Parameters
    ----------
    temperature : float
        Scale of the Gaussian kernel: the squared distance is divided by this
        value before exponentiation, so it controls how quickly the loss
        saturates with distance (the rate of damping).
    regularization_weight : float
        The weight of the L2 regularizer on the object-centric embeddings.
    density : float
        Determines the fraction of patches to sample per crop during
        training. Stored as an attribute.
    num_spatial_dims : int
        Should be equal to 2 for 2D and 3 for 3D. Stored as an attribute.
    device : torch.device
        The device to train on. Set to 'cpu' to train without GPU.
        Stored as an attribute.
    """
    super().__init__()
    self.temperature = temperature
    self.regularization_weight = regularization_weight
    self.density = density
    self.num_spatial_dims = num_spatial_dims
    self.device = device