ptame.models.components.loss

  1import abc
  2from dataclasses import dataclass
  3from typing import ClassVar
  4
  5import torch
  6from torch.nn.functional import cross_entropy
  7
  8
  9@dataclass
 10class Loss(abc.ABC):
 11    """Loss is an abstract class that represents a loss function for the TAME-
 12    based family of explainabiltiy methods."""
 13
 14    num_terms: ClassVar[int]
 15
 16    @abc.abstractmethod
 17    def __call__(
 18        self,
 19        logits: torch.Tensor,
 20        logits_masked: torch.Tensor,
 21        targets: torch.Tensor,
 22        masks: torch.Tensor,
 23        epoch: int,
 24    ) -> list[torch.Tensor]:
 25        """Calculates the loss based on the logits, targets, and masks.
 26
 27        Args:
 28            logits (torch.Tensor): The predicted logits.
 29            logits_masked (torch.Tensor): The predicted logits with the masks applied.
 30            targets (torch.Tensor): The targets.
 31            masks (torch.Tensor): The masks tensor.
 32
 33        Returns:
 34            List[torch.Tensor]: A list of calculated losses.
 35        """
 36
 37    def area_loss(self, masks):
 38        """Calculates the area loss based on the masks.
 39
 40        Args:
 41            masks (torch.Tensor): The masks tensor.
 42
 43        Returns:
 44            torch.Tensor: The calculated area loss.
 45        """
 46        if self.area_loss_power != 1:
 47            # add e to prevent nan (derivative of sqrt at 0 is inf)
 48            masks = (masks + 0.0005) ** self.area_loss_power
 49        return torch.mean(masks)
 50
 51    def smoothness_loss(self, masks):
 52        """Calculates the smoothness loss based on the masks.
 53
 54        Args:
 55            masks (torch.Tensor): The masks tensor.
 56
 57        Returns:
 58            torch.Tensor: The calculated smoothness loss.
 59        """
 60        B, _, _, _ = masks.size()
 61        border_penalty = self.smoothness_border_penalty
 62        power = self.smoothness_power
 63        x_loss = torch.sum(
 64            (torch.abs(masks[:, :, 1:, :] - masks[:, :, :-1, :])) ** power
 65        )
 66        y_loss = torch.sum(
 67            (torch.abs(masks[:, :, :, 1:] - masks[:, :, :, :-1])) ** power
 68        )
 69        if border_penalty > 0:
 70            border = float(border_penalty) * torch.sum(
 71                masks[:, :, -1, :] ** power
 72                + masks[:, :, 0, :] ** power
 73                + masks[:, :, :, -1] ** power
 74                + masks[:, :, :, 0] ** power
 75            )
 76        else:
 77            border = 0.0
 78        return (x_loss + y_loss + border) / float(power * B)
 79
 80
 81@dataclass
 82class ClassicLoss(Loss):
 83    """ClassicLoss is a class that represents the loss function used in TAME.
 84
 85    It calculates the loss based on cross-entropy, area loss, and smoothness
 86    loss.
 87
 88    :param ce_coeff: The coefficient for the cross-entropy loss.
 89    :param area_coeff: The coefficient for the area loss.
 90    :param smoothness_coeff: The coefficient for the smoothness loss.
 91    :param smoothness_power: The power for the smoothness loss.
 92    :param smoothness_border_penalty: The penalty for the smoothness loss at
 93        the border.
 94    :param area_loss_power: The power for the area loss.
 95    :param num_terms: The number of terms returned by the loss function.
 96    """
 97
 98    ce_coeff: float = 1.5
 99    area_coeff: float = 2
100    smoothness_coeff: float = 0.01
101    smoothness_power: float = 2
102    smoothness_border_penalty: float = 0.3
103    area_loss_power: float = 0.3
104    num_terms: ClassVar[int] = 4
105
106    def __call__(
107        self,
108        logits_masked: torch.Tensor,
109        targets: torch.Tensor,
110        masks: torch.Tensor,
111        **kwargs,
112    ) -> list[torch.Tensor]:
113        """Calculates the overall loss based on the logits, targets, and masks.
114
115        Args:
116            logits (torch.Tensor): The predicted logits.
117            logits_masked (torch.Tensor): The predicted logits with the masks applied.
118            targets (torch.Tensor): The targets.
119            masks (torch.Tensor): The masks tensor.
120
121        Returns:
122            List[torch.Tensor]: A list of calculated losses, including the overall loss,
123            cross-entropy loss, area loss, and smoothness loss.
124        """
125        targets = targets.long()
126        variation_loss = torch.tensor(0)
127        area_loss = torch.tensor(0)
128        ce_loss = torch.tensor(0)
129        if self.smoothness_coeff > 0:
130            variation_loss = self.smoothness_loss(masks)
131        if self.area_coeff > 0:
132            area_loss = self.area_loss(masks)
133        if self.ce_coeff > 0:
134            ce_loss = cross_entropy(logits_masked, targets)
135        ce_loss = cross_entropy(logits_masked, targets)
136
137        loss = (
138            self.ce_coeff * ce_loss
139            + self.area_coeff * area_loss
140            + self.smoothness_coeff * variation_loss
141        )
142
143        return [loss, ce_loss, area_loss, variation_loss]
@dataclass
class Loss(abc.ABC):
@dataclass
class Loss(abc.ABC):
    """Abstract base class for loss functions of the TAME-based family of
    explainability methods.

    Subclasses implement ``__call__`` and set ``num_terms``. The shared
    helpers :meth:`area_loss` and :meth:`smoothness_loss` read the attributes
    ``area_loss_power``, ``smoothness_power`` and ``smoothness_border_penalty``
    from ``self``. This base class does not define them, so any subclass that
    uses the helpers must provide them (e.g. as dataclass fields).
    """

    # Number of tensors in the list returned by ``__call__``.
    num_terms: ClassVar[int]

    @abc.abstractmethod
    def __call__(
        self,
        logits: torch.Tensor,
        logits_masked: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor,
        epoch: int,
    ) -> list[torch.Tensor]:
        """Calculates the loss based on the logits, targets, and masks.

        Args:
            logits (torch.Tensor): The predicted logits.
            logits_masked (torch.Tensor): The predicted logits with the masks applied.
            targets (torch.Tensor): The targets.
            masks (torch.Tensor): The masks tensor.
            epoch (int): The epoch number supplied by the caller.

        Returns:
            List[torch.Tensor]: A list of calculated losses (``num_terms`` long).
        """

    def area_loss(self, masks):
        """Return the mean mask value, optionally raised to a power.

        Args:
            masks (torch.Tensor): The masks tensor.

        Returns:
            torch.Tensor: The calculated area loss (mean over all elements).
        """
        exponent = self.area_loss_power
        if exponent == 1:
            return torch.mean(masks)
        # Shift away from zero before exponentiating: for fractional
        # exponents the derivative of x ** p is unbounded at x == 0, which
        # would yield inf/nan gradients.
        return torch.mean((masks + 0.0005) ** exponent)

    def smoothness_loss(self, masks):
        """Calculates the smoothness loss based on the masks.

        Penalises differences between neighbouring entries along dims 2 and 3
        of ``masks`` and, when ``smoothness_border_penalty > 0``, mask mass on
        the first/last index of those dims.

        Args:
            masks (torch.Tensor): 4-D masks tensor (the code unpacks four
                sizes). Dims 2 and 3 need not be equal in length.

        Returns:
            torch.Tensor: The smoothness loss, normalised by
            ``smoothness_power * masks.size(0)``.
        """
        B, _, _, _ = masks.size()
        border_penalty = self.smoothness_border_penalty
        power = self.smoothness_power
        # Differences between neighbours along dim 2.
        x_loss = torch.sum(
            (torch.abs(masks[:, :, 1:, :] - masks[:, :, :-1, :])) ** power
        )
        # Differences between neighbours along dim 3.
        y_loss = torch.sum(
            (torch.abs(masks[:, :, :, 1:] - masks[:, :, :, :-1])) ** power
        )
        if border_penalty > 0:
            # Sum the first/last rows and first/last columns separately: they
            # have different lengths when the mask is not square, so adding
            # them elementwise would fail to broadcast.
            rows = masks[:, :, -1, :] ** power + masks[:, :, 0, :] ** power
            cols = masks[:, :, :, -1] ** power + masks[:, :, :, 0] ** power
            border = float(border_penalty) * (torch.sum(rows) + torch.sum(cols))
        else:
            border = 0.0
        return (x_loss + y_loss + border) / float(power * B)

Loss is an abstract class that represents a loss function for the TAME-based family of explainability methods.

num_terms: ClassVar[int]
def area_loss(self, masks):
def area_loss(self, masks):
    """Return the mean mask value, optionally raised to a power.

    When ``self.area_loss_power`` is not 1, every mask entry is shifted by a
    small constant (0.0005) and raised to that power before averaging.

    Args:
        masks (torch.Tensor): The masks tensor.

    Returns:
        torch.Tensor: The calculated area loss (mean over all elements).
    """
    exponent = self.area_loss_power
    if exponent == 1:
        return torch.mean(masks)
    # Shift away from zero before exponentiating: for fractional exponents
    # the derivative of x ** p is unbounded at x == 0 (inf/nan gradients).
    return torch.mean((masks + 0.0005) ** exponent)

Calculates the area loss based on the masks.

Args: masks (torch.Tensor): The masks tensor.

Returns: torch.Tensor: The calculated area loss.

def smoothness_loss(self, masks):
def smoothness_loss(self, masks):
    """Calculates the smoothness loss based on the masks.

    Penalises differences between neighbouring entries along dims 2 and 3 of
    ``masks`` and, when ``smoothness_border_penalty > 0``, mask mass on the
    first/last index of those dims.

    Args:
        masks (torch.Tensor): 4-D masks tensor (the code unpacks four sizes).
            Dims 2 and 3 need not be equal in length.

    Returns:
        torch.Tensor: The smoothness loss, normalised by
        ``smoothness_power * masks.size(0)``.
    """
    B, _, _, _ = masks.size()
    border_penalty = self.smoothness_border_penalty
    power = self.smoothness_power
    # Differences between neighbours along dim 2.
    x_loss = torch.sum(
        (torch.abs(masks[:, :, 1:, :] - masks[:, :, :-1, :])) ** power
    )
    # Differences between neighbours along dim 3.
    y_loss = torch.sum(
        (torch.abs(masks[:, :, :, 1:] - masks[:, :, :, :-1])) ** power
    )
    if border_penalty > 0:
        # Sum the first/last rows and first/last columns separately: they have
        # different lengths when the mask is not square, so adding them
        # elementwise would fail to broadcast.
        rows = masks[:, :, -1, :] ** power + masks[:, :, 0, :] ** power
        cols = masks[:, :, :, -1] ** power + masks[:, :, :, 0] ** power
        border = float(border_penalty) * (torch.sum(rows) + torch.sum(cols))
    else:
        border = 0.0
    return (x_loss + y_loss + border) / float(power * B)

Calculates the smoothness loss based on the masks.

Args: masks (torch.Tensor): The masks tensor.

Returns: torch.Tensor: The calculated smoothness loss.

@dataclass
class ClassicLoss(Loss):
 82@dataclass
 83class ClassicLoss(Loss):
 84    """ClassicLoss is a class that represents the loss function used in TAME.
 85
 86    It calculates the loss based on cross-entropy, area loss, and smoothness
 87    loss.
 88
 89    :param ce_coeff: The coefficient for the cross-entropy loss.
 90    :param area_coeff: The coefficient for the area loss.
 91    :param smoothness_coeff: The coefficient for the smoothness loss.
 92    :param smoothness_power: The power for the smoothness loss.
 93    :param smoothness_border_penalty: The penalty for the smoothness loss at
 94        the border.
 95    :param area_loss_power: The power for the area loss.
 96    :param num_terms: The number of terms returned by the loss function.
 97    """
 98
 99    ce_coeff: float = 1.5
100    area_coeff: float = 2
101    smoothness_coeff: float = 0.01
102    smoothness_power: float = 2
103    smoothness_border_penalty: float = 0.3
104    area_loss_power: float = 0.3
105    num_terms: ClassVar[int] = 4
106
107    def __call__(
108        self,
109        logits_masked: torch.Tensor,
110        targets: torch.Tensor,
111        masks: torch.Tensor,
112        **kwargs,
113    ) -> list[torch.Tensor]:
114        """Calculates the overall loss based on the logits, targets, and masks.
115
116        Args:
117            logits (torch.Tensor): The predicted logits.
118            logits_masked (torch.Tensor): The predicted logits with the masks applied.
119            targets (torch.Tensor): The targets.
120            masks (torch.Tensor): The masks tensor.
121
122        Returns:
123            List[torch.Tensor]: A list of calculated losses, including the overall loss,
124            cross-entropy loss, area loss, and smoothness loss.
125        """
126        targets = targets.long()
127        variation_loss = torch.tensor(0)
128        area_loss = torch.tensor(0)
129        ce_loss = torch.tensor(0)
130        if self.smoothness_coeff > 0:
131            variation_loss = self.smoothness_loss(masks)
132        if self.area_coeff > 0:
133            area_loss = self.area_loss(masks)
134        if self.ce_coeff > 0:
135            ce_loss = cross_entropy(logits_masked, targets)
136        ce_loss = cross_entropy(logits_masked, targets)
137
138        loss = (
139            self.ce_coeff * ce_loss
140            + self.area_coeff * area_loss
141            + self.smoothness_coeff * variation_loss
142        )
143
144        return [loss, ce_loss, area_loss, variation_loss]

ClassicLoss is a class that represents the loss function used in TAME.

It calculates the loss based on cross-entropy, area loss, and smoothness loss.

Parameters
  • ce_coeff: The coefficient for the cross-entropy loss.
  • area_coeff: The coefficient for the area loss.
  • smoothness_coeff: The coefficient for the smoothness loss.
  • smoothness_power: The power for the smoothness loss.
  • smoothness_border_penalty: The penalty for the smoothness loss at the border.
  • area_loss_power: The power for the area loss.
  • num_terms: The number of terms returned by the loss function.
ClassicLoss( ce_coeff: float = 1.5, area_coeff: float = 2, smoothness_coeff: float = 0.01, smoothness_power: float = 2, smoothness_border_penalty: float = 0.3, area_loss_power: float = 0.3)
ce_coeff: float = 1.5
area_coeff: float = 2
smoothness_coeff: float = 0.01
smoothness_power: float = 2
smoothness_border_penalty: float = 0.3
area_loss_power: float = 0.3
num_terms: ClassVar[int] = 4
Inherited Members
Loss
area_loss
smoothness_loss