ptame.models.components.loss
import abc
from dataclasses import dataclass
from typing import ClassVar

import torch
from torch.nn.functional import cross_entropy


@dataclass
class Loss(abc.ABC):
    """Loss is an abstract class that represents a loss function for the
    TAME-based family of explainability methods.

    Subclasses set ``num_terms`` and implement ``__call__``. The shared
    helpers :meth:`area_loss` and :meth:`smoothness_loss` read
    ``area_loss_power``, ``smoothness_power`` and
    ``smoothness_border_penalty`` from ``self``. This base class does not
    declare them, so subclasses that use the helpers must define them.
    """

    # Number of tensors in the list returned by ``__call__``.
    num_terms: ClassVar[int]

    @abc.abstractmethod
    def __call__(
        self,
        logits: torch.Tensor,
        logits_masked: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor,
        epoch: int,
    ) -> list[torch.Tensor]:
        """Calculates the loss based on the logits, targets, and masks.

        Args:
            logits (torch.Tensor): The predicted logits.
            logits_masked (torch.Tensor): The predicted logits with the masks applied.
            targets (torch.Tensor): The targets.
            masks (torch.Tensor): The masks tensor.
            epoch (int): An epoch index (presumably the current training
                epoch; not used by the helpers in this class).

        Returns:
            List[torch.Tensor]: A list of calculated losses.
        """

    def area_loss(self, masks: torch.Tensor) -> torch.Tensor:
        """Calculates the area loss based on the masks.

        The loss is the mean of ``masks``. When ``self.area_loss_power`` is
        not 1, each element is first shifted by 0.0005 and raised to that
        power.

        Args:
            masks (torch.Tensor): The masks tensor.

        Returns:
            torch.Tensor: The calculated area loss (a scalar tensor).
        """
        if self.area_loss_power != 1:
            # Add a small epsilon so the gradient of x**p stays finite at
            # x == 0 for fractional powers (e.g. sqrt).
            masks = (masks + 0.0005) ** self.area_loss_power
        return torch.mean(masks)

    def smoothness_loss(self, masks: torch.Tensor) -> torch.Tensor:
        """Calculates the smoothness loss based on the masks.

        The loss is the sum of ``|difference| ** smoothness_power`` between
        neighbouring elements along dims 2 and 3. When
        ``smoothness_border_penalty > 0``, it adds that penalty times the sum
        of the four edge slices raised to ``smoothness_power``. The total is
        divided by ``smoothness_power * batch_size``.

        Args:
            masks (torch.Tensor): The masks tensor. It must be 4-D (the size
                unpacking below requires it), presumably (B, C, H, W).

        Returns:
            torch.Tensor: The calculated smoothness loss.
        """
        B, _, _, _ = masks.size()
        border_penalty = self.smoothness_border_penalty
        power = self.smoothness_power
        x_loss = torch.sum(
            (torch.abs(masks[:, :, 1:, :] - masks[:, :, :-1, :])) ** power
        )
        y_loss = torch.sum(
            (torch.abs(masks[:, :, :, 1:] - masks[:, :, :, :-1])) ** power
        )
        if border_penalty > 0:
            # Reduce each edge on its own. The first/last rows and the
            # first/last columns have different lengths for non-square masks,
            # so adding them elementwise would fail to broadcast.
            edges = (
                masks[:, :, -1, :],
                masks[:, :, 0, :],
                masks[:, :, :, -1],
                masks[:, :, :, 0],
            )
            border = float(border_penalty) * sum(
                torch.sum(edge**power) for edge in edges
            )
        else:
            border = 0.0
        return (x_loss + y_loss + border) / float(power * B)


@dataclass
class ClassicLoss(Loss):
    """ClassicLoss is a class that represents the loss function used in TAME.

    It calculates the loss based on cross-entropy, area loss, and smoothness
    loss.

    :param ce_coeff: The coefficient for the cross-entropy loss.
    :param area_coeff: The coefficient for the area loss.
    :param smoothness_coeff: The coefficient for the smoothness loss.
    :param smoothness_power: The power for the smoothness loss.
    :param smoothness_border_penalty: The penalty for the smoothness loss at
        the border.
    :param area_loss_power: The power for the area loss.
    :cvar num_terms: The number of terms returned by the loss function
        (a class variable, not a constructor argument).
    """

    ce_coeff: float = 1.5
    area_coeff: float = 2
    smoothness_coeff: float = 0.01
    smoothness_power: float = 2
    smoothness_border_penalty: float = 0.3
    area_loss_power: float = 0.3
    num_terms: ClassVar[int] = 4

    # NOTE(review): this signature differs from Loss.__call__ (no positional
    # ``logits``/``epoch``); callers presumably pass arguments by keyword.
    def __call__(
        self,
        logits_masked: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor,
        **kwargs,
    ) -> list[torch.Tensor]:
        """Calculates the overall loss based on the masked logits, targets,
        and masks.

        A term whose coefficient is not positive is not computed and is
        reported as a zero tensor.

        Args:
            logits_masked (torch.Tensor): The predicted logits with the masks applied.
            targets (torch.Tensor): The targets (cast to ``long`` before use).
            masks (torch.Tensor): The masks tensor.
            **kwargs: Ignored. This lets the unused base-class arguments
                (e.g. ``logits``, ``epoch``) be passed by keyword.

        Returns:
            List[torch.Tensor]: A list of calculated losses, including the overall loss,
            cross-entropy loss, area loss, and smoothness loss.
        """
        targets = targets.long()

        # Disabled terms are zeros matching the masks' dtype and device,
        # rather than CPU int64 scalars.
        if self.ce_coeff > 0:
            ce_loss = cross_entropy(logits_masked, targets)
        else:
            ce_loss = masks.new_zeros(())
        if self.area_coeff > 0:
            area_loss = self.area_loss(masks)
        else:
            area_loss = masks.new_zeros(())
        if self.smoothness_coeff > 0:
            variation_loss = self.smoothness_loss(masks)
        else:
            variation_loss = masks.new_zeros(())

        loss = (
            self.ce_coeff * ce_loss
            + self.area_coeff * area_loss
            + self.smoothness_coeff * variation_loss
        )

        return [loss, ce_loss, area_loss, variation_loss]
@dataclass
class
@dataclass
class Loss(abc.ABC):
    """Loss is an abstract class that represents a loss function for the
    TAME-based family of explainability methods.

    Subclasses set ``num_terms`` and implement ``__call__``. The shared
    helpers :meth:`area_loss` and :meth:`smoothness_loss` read
    ``area_loss_power``, ``smoothness_power`` and
    ``smoothness_border_penalty`` from ``self``. This base class does not
    declare them, so subclasses that use the helpers must define them.
    """

    # Number of tensors in the list returned by ``__call__``.
    num_terms: ClassVar[int]

    @abc.abstractmethod
    def __call__(
        self,
        logits: torch.Tensor,
        logits_masked: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor,
        epoch: int,
    ) -> list[torch.Tensor]:
        """Calculates the loss based on the logits, targets, and masks.

        Args:
            logits (torch.Tensor): The predicted logits.
            logits_masked (torch.Tensor): The predicted logits with the masks applied.
            targets (torch.Tensor): The targets.
            masks (torch.Tensor): The masks tensor.
            epoch (int): An epoch index (presumably the current training
                epoch; not used by the helpers in this class).

        Returns:
            List[torch.Tensor]: A list of calculated losses.
        """

    def area_loss(self, masks: torch.Tensor) -> torch.Tensor:
        """Calculates the area loss based on the masks.

        The loss is the mean of ``masks``. When ``self.area_loss_power`` is
        not 1, each element is first shifted by 0.0005 and raised to that
        power.

        Args:
            masks (torch.Tensor): The masks tensor.

        Returns:
            torch.Tensor: The calculated area loss (a scalar tensor).
        """
        if self.area_loss_power != 1:
            # Add a small epsilon so the gradient of x**p stays finite at
            # x == 0 for fractional powers (e.g. sqrt).
            masks = (masks + 0.0005) ** self.area_loss_power
        return torch.mean(masks)

    def smoothness_loss(self, masks: torch.Tensor) -> torch.Tensor:
        """Calculates the smoothness loss based on the masks.

        The loss is the sum of ``|difference| ** smoothness_power`` between
        neighbouring elements along dims 2 and 3. When
        ``smoothness_border_penalty > 0``, it adds that penalty times the sum
        of the four edge slices raised to ``smoothness_power``. The total is
        divided by ``smoothness_power * batch_size``.

        Args:
            masks (torch.Tensor): The masks tensor. It must be 4-D (the size
                unpacking below requires it), presumably (B, C, H, W).

        Returns:
            torch.Tensor: The calculated smoothness loss.
        """
        B, _, _, _ = masks.size()
        border_penalty = self.smoothness_border_penalty
        power = self.smoothness_power
        x_loss = torch.sum(
            (torch.abs(masks[:, :, 1:, :] - masks[:, :, :-1, :])) ** power
        )
        y_loss = torch.sum(
            (torch.abs(masks[:, :, :, 1:] - masks[:, :, :, :-1])) ** power
        )
        if border_penalty > 0:
            # Reduce each edge on its own. The first/last rows and the
            # first/last columns have different lengths for non-square masks,
            # so adding them elementwise would fail to broadcast.
            edges = (
                masks[:, :, -1, :],
                masks[:, :, 0, :],
                masks[:, :, :, -1],
                masks[:, :, :, 0],
            )
            border = float(border_penalty) * sum(
                torch.sum(edge**power) for edge in edges
            )
        else:
            border = 0.0
        return (x_loss + y_loss + border) / float(power * B)
Loss is an abstract class that represents a loss function for the TAME-based family of explainability methods.
def
area_loss(self, masks):
def area_loss(self, masks):
    """Return the area term: the mean of ``masks``.

    When ``self.area_loss_power`` is not 1, every element is shifted by a
    small epsilon (0.0005) and raised to that power before averaging.

    Args:
        masks (torch.Tensor): The masks tensor.

    Returns:
        torch.Tensor: A scalar tensor holding the area loss.
    """
    exponent = self.area_loss_power
    if exponent == 1:
        return torch.mean(masks)
    # The epsilon keeps the gradient finite at zero for fractional exponents
    # (d/dx x**p diverges at x = 0 when p < 1, e.g. sqrt).
    shifted = masks + 0.0005
    return torch.mean(shifted ** exponent)
Calculates the area loss based on the masks.
Args: masks (torch.Tensor): The masks tensor.
Returns: torch.Tensor: The calculated area loss.
def
smoothness_loss(self, masks):
def smoothness_loss(self, masks):
    """Calculates the smoothness loss based on the masks.

    The loss is the sum of ``|difference| ** smoothness_power`` between
    neighbouring elements along dims 2 and 3. When
    ``smoothness_border_penalty > 0``, it adds that penalty times the sum of
    the four edge slices raised to ``smoothness_power``. The total is
    divided by ``smoothness_power * batch_size``.

    Args:
        masks (torch.Tensor): The masks tensor. It must be 4-D (the size
            unpacking below requires it), presumably (B, C, H, W).

    Returns:
        torch.Tensor: The calculated smoothness loss.
    """
    B, _, _, _ = masks.size()
    border_penalty = self.smoothness_border_penalty
    power = self.smoothness_power
    x_loss = torch.sum(
        (torch.abs(masks[:, :, 1:, :] - masks[:, :, :-1, :])) ** power
    )
    y_loss = torch.sum(
        (torch.abs(masks[:, :, :, 1:] - masks[:, :, :, :-1])) ** power
    )
    if border_penalty > 0:
        # Reduce each edge on its own. The first/last rows and the
        # first/last columns have different lengths for non-square masks,
        # so adding them elementwise would fail to broadcast.
        edges = (
            masks[:, :, -1, :],
            masks[:, :, 0, :],
            masks[:, :, :, -1],
            masks[:, :, :, 0],
        )
        border = float(border_penalty) * sum(
            torch.sum(edge**power) for edge in edges
        )
    else:
        border = 0.0
    return (x_loss + y_loss + border) / float(power * B)
Calculates the smoothness loss based on the masks.
Args: masks (torch.Tensor): The masks tensor.
Returns: torch.Tensor: The calculated smoothness loss.
@dataclass
class ClassicLoss(Loss):
    """ClassicLoss is a class that represents the loss function used in TAME.

    It calculates the loss based on cross-entropy, area loss, and smoothness
    loss.

    :param ce_coeff: The coefficient for the cross-entropy loss.
    :param area_coeff: The coefficient for the area loss.
    :param smoothness_coeff: The coefficient for the smoothness loss.
    :param smoothness_power: The power for the smoothness loss.
    :param smoothness_border_penalty: The penalty for the smoothness loss at
        the border.
    :param area_loss_power: The power for the area loss.
    :cvar num_terms: The number of terms returned by the loss function
        (a class variable, not a constructor argument).
    """

    ce_coeff: float = 1.5
    area_coeff: float = 2
    smoothness_coeff: float = 0.01
    smoothness_power: float = 2
    smoothness_border_penalty: float = 0.3
    area_loss_power: float = 0.3
    num_terms: ClassVar[int] = 4

    # NOTE(review): this signature differs from Loss.__call__ (no positional
    # ``logits``/``epoch``); callers presumably pass arguments by keyword.
    def __call__(
        self,
        logits_masked: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor,
        **kwargs,
    ) -> list[torch.Tensor]:
        """Calculates the overall loss based on the masked logits, targets,
        and masks.

        A term whose coefficient is not positive is not computed and is
        reported as a zero tensor.

        Args:
            logits_masked (torch.Tensor): The predicted logits with the masks applied.
            targets (torch.Tensor): The targets (cast to ``long`` before use).
            masks (torch.Tensor): The masks tensor.
            **kwargs: Ignored. This lets unused arguments such as ``logits``
                or ``epoch`` be passed by keyword.

        Returns:
            List[torch.Tensor]: A list of calculated losses, including the overall loss,
            cross-entropy loss, area loss, and smoothness loss.
        """
        targets = targets.long()

        # Disabled terms are zeros matching the masks' dtype and device,
        # rather than CPU int64 scalars.
        if self.ce_coeff > 0:
            ce_loss = cross_entropy(logits_masked, targets)
        else:
            ce_loss = masks.new_zeros(())
        if self.area_coeff > 0:
            area_loss = self.area_loss(masks)
        else:
            area_loss = masks.new_zeros(())
        if self.smoothness_coeff > 0:
            variation_loss = self.smoothness_loss(masks)
        else:
            variation_loss = masks.new_zeros(())

        loss = (
            self.ce_coeff * ce_loss
            + self.area_coeff * area_loss
            + self.smoothness_coeff * variation_loss
        )

        return [loss, ce_loss, area_loss, variation_loss]
ClassicLoss is a class that represents the loss function used in TAME.
It calculates the loss based on cross-entropy, area loss, and smoothness loss.
Parameters
- ce_coeff: The coefficient for the cross-entropy loss.
- area_coeff: The coefficient for the area loss.
- smoothness_coeff: The coefficient for the smoothness loss.
- smoothness_power: The power for the smoothness loss.
- smoothness_border_penalty: The penalty for the smoothness loss at the border.
- area_loss_power: The power for the area loss.
- num_terms: The number of terms returned by the loss function.
ClassicLoss( ce_coeff: float = 1.5, area_coeff: float = 2, smoothness_coeff: float = 0.01, smoothness_power: float = 2, smoothness_border_penalty: float = 0.3, area_loss_power: float = 0.3)