ptame.models.components.ttame
from typing import List, Tuple, Type

import torch
from torch import nn
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)

from ptame.models.components.ptame_pipeline import PtamePipeline
from ptame.utils.map_printer import XAIModel


class TTAME(PtamePipeline, XAIModel):
    """Transformer-compatible Trainable Attention Mechanism for
    Explainability.

    Combines a backbone that yields ``(logits, features)``, an attention
    mechanism that turns the features into saliency maps, and a masking
    procedure that applies the maps to the input.
    """

    def __init__(
        self,
        backbone: nn.Module,
        attention: nn.Module,
        masking_procedure,
    ):
        """Hand the three components over to the parent initialiser.

        Args:
            backbone: Module whose call returns ``(logits, features)``.
            attention: Module mapping the features to saliency maps.
            masking_procedure: Callable invoked as
                ``masking_procedure(x, maps)``.
        """
        super().__init__(backbone, attention, masking_procedure)

    def _forward_backbone(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward pass of the backbone (feature extractor)."""
        y, features = self.backbone(x)
        return y, features

    def _forward_second(
        self, x: torch.Tensor, features: List[torch.Tensor], targets
    ):
        """Second forward pass, for the training stage only.

        Returns:
            The backbone output for the masked input and the saliency maps
            picked by ``_select_returned_maps``.
        """
        # Get the saliency maps from the attention mechanism
        saliency_maps = self.attention(features)
        # Apply the masks to the input tensor and get the output
        x_masked = self.masking_procedure(
            x, self._select_matching_maps(saliency_maps, targets)
        )
        y_masked, _ = self.backbone(x_masked)
        # Select the saliency maps which will be returned based on the
        # targets and the stage
        saliency_maps = self._select_returned_maps(saliency_maps, targets)
        return y_masked, saliency_maps

    def forward(self, x: torch.Tensor):
        """Forward pass of the entire model.

        The first backbone pass runs on a detached copy of ``x``; its
        arg-max over dim 1 is used as the targets for the masked pass.

        Returns:
            dict with keys ``"logits"``, ``"logits_masked"``, ``"targets"``
            and ``"masks"``.
        """
        y, features = self._forward_backbone(x.clone().detach())
        targets = y.argmax(dim=1)
        y_masked, maps = self._forward_second(x, features, targets)
        return {
            "logits": y,
            "logits_masked": y_masked,
            "targets": targets,
            "masks": maps,
        }

    def get_predictions(self, x):
        """Get the predictions of the model.

        Useful for the measures.
        """
        return self._forward_backbone(x)[0]

    def produce_map(self, image: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Return the saliency map of the predicted class for one image.

        Args:
            image: A single image; a batch dimension is prepended here.

        Returns:
            ``(saliency_map, prediction)`` where ``prediction`` is the
            arg-max of the backbone output and ``saliency_map`` is
            ``maps[0, prediction]``.
        """
        image = image.unsqueeze(0)
        out, features = self._forward_backbone(image)
        maps = self.attention(features)
        prediction = out[0].argmax().item()
        # named ``saliency_map`` so the builtin ``map`` is not shadowed
        saliency_map = maps[0, prediction]
        return saliency_map, prediction

    def produce_cdmaps(
        self, image: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the full attention output and backbone output for an image.

        Args:
            image: A single image; a batch dimension is prepended here.

        Returns:
            ``(maps, predictions)``: the unmodified output of the attention
            mechanism and the first row of the backbone output (a tensor,
            not a single class index).
        """
        image = image.unsqueeze(0)
        out, features = self._forward_backbone(image)
        maps = self.attention(features)
        predictions = out[0]
        return maps, predictions


class TTAMEBuilder:
    """Builder for the TTAME model.

    Turns a plain backbone into a feature extractor, instantiates the
    attention class from the recorded feature shapes and assembles either a
    ``TTAME`` model or a ``PtamePipeline``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        attention: Type[nn.Module],
        masking_procedure,
        layers: List[str],
        input_dim: List[int] = None,
        num_classes: int = 1000,
        **kwargs,
    ):
        """Store the build configuration.

        Args:
            backbone: Model to be wrapped by the feature extractor.
            attention: Attention class, called as
                ``attention(feature_size, num_classes)``.
            masking_procedure: Passed through to the built model.
            layers: Graph node names whose outputs are used as features.
            input_dim: Shape of the dry-run input; ``(1, 3, 224, 224)``
                is used when ``None``.
            num_classes: Second argument given to the attention class.
            **kwargs: Extra arguments forwarded to ``PtamePipeline`` by
                ``build_pipeline``.
        """
        self.backbone = backbone
        self.attention = attention
        self.masking_procedure = masking_procedure
        self.layers = layers
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.kwargs = kwargs

    @torch.no_grad()
    def build(self) -> TTAME:
        """Build the TTAME model."""
        # build the feature extractor
        self._build_fx()
        # build the attention mechanism
        self._build_attention()
        # build the model
        return TTAME(
            self.backbone,
            self.attention,
            self.masking_procedure,
        )

    @torch.no_grad()
    def build_pipeline(self) -> PtamePipeline:
        """Build a ``PtamePipeline`` with TTAME-specific backbone steps.

        The ``"backbone"`` and ``"masked_backbone"`` hooks default to a step
        that unpacks the ``(logits, features)`` tuple of the feature
        extractor; hooks supplied through ``kwargs["hooks"]`` take
        precedence over these defaults.
        """
        # build the feature extractor
        self._build_fx()
        # build the attention mechanism
        self._build_attention()

        def ttame_backbone_step(pipeline, _default_step):
            def backbone_step(x: dict) -> dict:
                """Forward pass of the backbone."""
                if (inp := x.get("masked_images")) is not None:
                    x["logits_masked"], _ = pipeline.backbone(inp)
                    x["targets_masked"] = x["logits_masked"].argmax(dim=1)
                else:
                    x["logits"], x["features"] = pipeline.backbone(
                        x["images"]
                    )
                    x["targets"] = x["logits"].argmax(dim=1)
                return x

            return backbone_step

        self.kwargs["hooks"] = {
            "backbone": ttame_backbone_step,
            "masked_backbone": ttame_backbone_step,
        } | self.kwargs.get("hooks", {})
        # build the model
        return PtamePipeline(
            self.backbone,
            self.attention,
            self.masking_procedure,
            **self.kwargs,
        )

    def _build_fx(self):
        """Build feature extractor.

        Replaces ``self.backbone`` with a feature extractor returning the
        requested layers plus the final graph node, switches it to eval
        mode, records the feature shapes of a dry run in
        ``self.feature_size`` and registers a hook that converts the output
        dict into ``(logits, features)``.

        Raises:
            SystemExit: After printing the node names when no layers were
                specified.
        """
        train_names, eval_names = get_graph_node_names(self.backbone)
        # if no layers are specified, print layers and quit
        if not self.layers:
            print(train_names)
            # ``quit()`` is injected by the ``site`` module and may be
            # missing; raising SystemExit directly has the same effect.
            raise SystemExit
        # get the output layer name
        self.output_name = train_names[-1]
        if self.output_name != eval_names[-1]:
            print(
                "WARNING! THIS MODEL HAS DIFFERENT OUTPUTS FOR TRAIN AND EVAL MODE"
            )
        # get feature extractor; unpacking accepts any sequence of names
        self.backbone = create_feature_extractor(
            self.backbone, return_nodes=[*self.layers, self.output_name]
        )
        # Dry run to get number of channels of each layer for the attention
        # mechanism
        if self.input_dim is not None:
            inp = torch.randn(self.input_dim)
        else:
            inp = torch.randn(1, 3, 224, 224)
        self.backbone.eval()
        outputs = self.backbone(inp)
        outputs.pop(self.output_name)
        self.feature_size = [o.shape for o in outputs.values()]
        self.backbone.register_forward_hook(self._simplify_graph_outputs())

    def _build_attention(self):
        """Build the attention mechanism.

        Replaces the attention class stored in ``self.attention`` with an
        instance. For transformer backbones the feature sizes are rewritten
        to ``(2, last_dim, 14, 14)`` and a pre-hook reshapes the incoming
        features accordingly.
        """
        # check if the model is a transformer
        if self._is_transformer():
            feature_size = [
                torch.Size([2, ft[-1], 14, 14]) for ft in self.feature_size
            ]
            self.attention = self.attention(feature_size, self.num_classes)
            self.attention.register_forward_pre_hook(
                self._feature_adapter_vit_b_16(), with_kwargs=True
            )
        else:
            self.attention = self.attention(
                self.feature_size, self.num_classes
            )

    def _is_transformer(self) -> bool:
        """Check if the model is a transformer.

        Returns:
            bool: True if the backbone contains an
            ``nn.MultiheadAttention`` module, False otherwise.
        """
        return any(
            isinstance(module, nn.MultiheadAttention)
            for module in self.backbone.modules()
        )

    def _simplify_graph_outputs(self):
        """Returns a hook to simplify the outputs of the GraphModule feature
        extractor."""

        def hook(
            module, inputs, outputs
        ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
            # the entry named ``output_name`` is the model output; every
            # remaining entry is a feature
            y: torch.Tensor = outputs.pop(self.output_name)
            features = list(outputs.values())
            return y, features

        return hook

    def _feature_adapter_vit_b_16(self):
        """Returns a hook to adapt the features of the Vision Transformer
        model.

        The hook drops the first token of every sequence, reshapes the rest
        to a 14x14 grid and moves the channel dimension next to the batch
        dimension. It adapts the features whether they arrive as the
        ``features`` keyword or as the first positional argument.
        """
        grid = 14

        def adapt(seq_list):
            adapted = []
            for seq in seq_list:
                # discard class token
                tokens = seq[:, 1:, :]
                # reshape to (batch, grid, grid, channels)
                tokens = tokens.reshape(seq.size(0), grid, grid, seq.size(2))
                # bring channels after batch dimension
                adapted.append(tokens.permute(0, 3, 1, 2))
            return adapted

        def hook(module, args, kwargs) -> Tuple[tuple, dict]:
            if "features" in kwargs:
                kwargs["features"] = adapt(kwargs["features"])
            elif args:
                # positional call, e.g. ``self.attention(features)``
                args = (adapt(args[0]), *args[1:])
            return args, kwargs

        return hook


def build_ttame(**kwargs) -> TTAME:
    """Build the TTAME model.

    Args:
        Same as TTAMEBuilder.
    Returns:
        TTAME: The TTAME model.
    """
    builder = TTAMEBuilder(**kwargs)
    return builder.build()


def build_ttame_pipeline(**kwargs) -> PtamePipeline:
    """Build the TTAME model as a pipeline.

    Args:
        Same as TTAMEBuilder.
    Returns:
        PtamePipeline: The pipeline produced by
        ``TTAMEBuilder.build_pipeline``.
    """
    builder = TTAMEBuilder(**kwargs)
    return builder.build_pipeline()
class
TTAME(ptame.models.components.ptame_pipeline.PtamePipeline, ptame.utils.map_printer.XAIModel):
class TTAME(PtamePipeline, XAIModel):
    """Transformer-compatible Trainable Attention Mechanism for
    Explainability.

    Combines a backbone that yields ``(logits, features)``, an attention
    mechanism that turns the features into saliency maps, and a masking
    procedure that applies the maps to the input.
    """

    def __init__(
        self,
        backbone: nn.Module,
        attention: nn.Module,
        masking_procedure,
    ):
        """Hand the three components over to the parent initialiser.

        Args:
            backbone: Module whose call returns ``(logits, features)``.
            attention: Module mapping the features to saliency maps.
            masking_procedure: Callable invoked as
                ``masking_procedure(x, maps)``.
        """
        super().__init__(backbone, attention, masking_procedure)

    def _forward_backbone(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward pass of the backbone (feature extractor)."""
        y, features = self.backbone(x)
        return y, features

    def _forward_second(
        self, x: torch.Tensor, features: List[torch.Tensor], targets
    ):
        """Second forward pass, for the training stage only.

        Returns:
            The backbone output for the masked input and the saliency maps
            picked by ``_select_returned_maps``.
        """
        # Get the saliency maps from the attention mechanism
        saliency_maps = self.attention(features)
        # Apply the masks to the input tensor and get the output
        x_masked = self.masking_procedure(
            x, self._select_matching_maps(saliency_maps, targets)
        )
        y_masked, _ = self.backbone(x_masked)
        # Select the saliency maps which will be returned based on the
        # targets and the stage
        saliency_maps = self._select_returned_maps(saliency_maps, targets)
        return y_masked, saliency_maps

    def forward(self, x: torch.Tensor):
        """Forward pass of the entire model.

        The first backbone pass runs on a detached copy of ``x``; its
        arg-max over dim 1 is used as the targets for the masked pass.

        Returns:
            dict with keys ``"logits"``, ``"logits_masked"``, ``"targets"``
            and ``"masks"``.
        """
        y, features = self._forward_backbone(x.clone().detach())
        targets = y.argmax(dim=1)
        y_masked, maps = self._forward_second(x, features, targets)
        return {
            "logits": y,
            "logits_masked": y_masked,
            "targets": targets,
            "masks": maps,
        }

    def get_predictions(self, x):
        """Get the predictions of the model.

        Useful for the measures.
        """
        return self._forward_backbone(x)[0]

    def produce_map(self, image: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Return the saliency map of the predicted class for one image.

        Args:
            image: A single image; a batch dimension is prepended here.

        Returns:
            ``(saliency_map, prediction)`` where ``prediction`` is the
            arg-max of the backbone output and ``saliency_map`` is
            ``maps[0, prediction]``.
        """
        image = image.unsqueeze(0)
        out, features = self._forward_backbone(image)
        maps = self.attention(features)
        prediction = out[0].argmax().item()
        # named ``saliency_map`` so the builtin ``map`` is not shadowed
        saliency_map = maps[0, prediction]
        return saliency_map, prediction

    def produce_cdmaps(
        self, image: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the full attention output and backbone output for an image.

        Args:
            image: A single image; a batch dimension is prepended here.

        Returns:
            ``(maps, predictions)``: the unmodified output of the attention
            mechanism and the first row of the backbone output (a tensor,
            not a single class index).
        """
        image = image.unsqueeze(0)
        out, features = self._forward_backbone(image)
        maps = self.attention(features)
        predictions = out[0]
        return maps, predictions
Transformer-compatible Trainable Attention Mechanism for Explainability.
TTAME( backbone: torch.nn.modules.module.Module, attention: torch.nn.modules.module.Module, masking_procedure)
def __init__(
    self,
    backbone: nn.Module,
    attention: nn.Module,
    masking_procedure,
):
    """Set up the model by delegating to the parent initialiser.

    Args:
        backbone: Module used as the feature-extracting backbone.
        attention: Module used as the attention mechanism.
        masking_procedure: Masking procedure, forwarded unchanged.
    """
    # Nothing TTAME-specific is stored here; the parent class receives
    # the three components in the same positional order.
    super().__init__(
        backbone,
        attention,
        masking_procedure,
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def
forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor):
    """Run the backbone, then the masked second pass, and bundle the results.

    The first pass sees a detached copy of ``x``; the arg-max of its output
    over dim 1 becomes the targets handed to ``_forward_second`` together
    with the original ``x`` and the extracted features.

    Returns:
        dict with the keys ``"logits"``, ``"logits_masked"``, ``"targets"``
        and ``"masks"``.
    """
    detached_input = x.clone().detach()
    logits, feats = self._forward_backbone(detached_input)
    predicted = logits.argmax(dim=1)
    second_pass = self._forward_second(x, feats, predicted)
    masked_logits, selected_maps = second_pass
    result = {"logits": logits, "logits_masked": masked_logits}
    result["targets"] = predicted
    result["masks"] = selected_maps
    return result
Forward pass of the entire model.
Depending on the stage, different outputs are returned.
def
get_predictions(self, x):
def get_predictions(self, x):
    """Return only the backbone output for ``x``.

    The features produced alongside it by ``_forward_backbone`` are
    discarded. Useful for the measures.
    """
    backbone_result = self._forward_backbone(x)
    predictions = backbone_result[0]
    return predictions
Get the predictions of the model.
Useful for the measures.
def
produce_map(self, image: torch.Tensor) -> Tuple[torch.Tensor, int]:
def produce_map(self, image: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Return the saliency map of the predicted class for one image.

    Args:
        image: A single image; a batch dimension is prepended here.

    Returns:
        ``(saliency_map, prediction)`` where ``prediction`` is the arg-max
        of the backbone output and ``saliency_map`` is
        ``maps[0, prediction]``.
    """
    image = image.unsqueeze(0)
    out, features = self._forward_backbone(image)
    maps = self.attention(features)
    prediction = out[0].argmax().item()
    # named ``saliency_map`` so the builtin ``map`` is not shadowed
    saliency_map = maps[0, prediction]
    return saliency_map, prediction
Produce a saliency map for the given image.
def
produce_cdmaps(self, image: torch.Tensor) -> Tuple[List[torch.Tensor], int]:
def produce_cdmaps(
    self, image: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the full attention output and backbone output for an image.

    Args:
        image: A single image; a batch dimension is prepended here.

    Returns:
        ``(maps, predictions)``: the unmodified output of the attention
        mechanism and the first row of the backbone output (a tensor, not
        a single class index).
    """
    image = image.unsqueeze(0)
    out, features = self._forward_backbone(image)
    maps = self.attention(features)
    predictions = out[0]
    return maps, predictions
Produce a list of saliency maps for the given image.
class
TTAMEBuilder:
class TTAMEBuilder:
    """Builder for the TTAME model.

    Turns a plain backbone into a feature extractor, instantiates the
    attention class from the recorded feature shapes and assembles either a
    ``TTAME`` model or a ``PtamePipeline``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        attention: Type[nn.Module],
        masking_procedure,
        layers: List[str],
        input_dim: List[int] = None,
        num_classes: int = 1000,
        **kwargs,
    ):
        """Store the build configuration.

        Args:
            backbone: Model to be wrapped by the feature extractor.
            attention: Attention class, called as
                ``attention(feature_size, num_classes)``.
            masking_procedure: Passed through to the built model.
            layers: Graph node names whose outputs are used as features.
            input_dim: Shape of the dry-run input; ``(1, 3, 224, 224)``
                is used when ``None``.
            num_classes: Second argument given to the attention class.
            **kwargs: Extra arguments forwarded to ``PtamePipeline`` by
                ``build_pipeline``.
        """
        self.backbone = backbone
        self.attention = attention
        self.masking_procedure = masking_procedure
        self.layers = layers
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.kwargs = kwargs

    @torch.no_grad()
    def build(self) -> "TTAME":
        """Build the TTAME model."""
        # build the feature extractor
        self._build_fx()
        # build the attention mechanism
        self._build_attention()
        # build the model
        return TTAME(
            self.backbone,
            self.attention,
            self.masking_procedure,
        )

    @torch.no_grad()
    def build_pipeline(self) -> "PtamePipeline":
        """Build a ``PtamePipeline`` with TTAME-specific backbone steps.

        The ``"backbone"`` and ``"masked_backbone"`` hooks default to a step
        that unpacks the ``(logits, features)`` tuple of the feature
        extractor; hooks supplied through ``kwargs["hooks"]`` take
        precedence over these defaults.
        """
        # build the feature extractor
        self._build_fx()
        # build the attention mechanism
        self._build_attention()

        def ttame_backbone_step(pipeline, _default_step):
            def backbone_step(x: dict) -> dict:
                """Forward pass of the backbone."""
                if (inp := x.get("masked_images")) is not None:
                    x["logits_masked"], _ = pipeline.backbone(inp)
                    x["targets_masked"] = x["logits_masked"].argmax(dim=1)
                else:
                    x["logits"], x["features"] = pipeline.backbone(
                        x["images"]
                    )
                    x["targets"] = x["logits"].argmax(dim=1)
                return x

            return backbone_step

        self.kwargs["hooks"] = {
            "backbone": ttame_backbone_step,
            "masked_backbone": ttame_backbone_step,
        } | self.kwargs.get("hooks", {})
        # build the model
        return PtamePipeline(
            self.backbone,
            self.attention,
            self.masking_procedure,
            **self.kwargs,
        )

    def _build_fx(self):
        """Build feature extractor.

        Replaces ``self.backbone`` with a feature extractor returning the
        requested layers plus the final graph node, switches it to eval
        mode, records the feature shapes of a dry run in
        ``self.feature_size`` and registers a hook that converts the output
        dict into ``(logits, features)``.

        Raises:
            SystemExit: After printing the node names when no layers were
                specified.
        """
        train_names, eval_names = get_graph_node_names(self.backbone)
        # if no layers are specified, print layers and quit
        if not self.layers:
            print(train_names)
            # ``quit()`` is injected by the ``site`` module and may be
            # missing; raising SystemExit directly has the same effect.
            raise SystemExit
        # get the output layer name
        self.output_name = train_names[-1]
        if self.output_name != eval_names[-1]:
            print(
                "WARNING! THIS MODEL HAS DIFFERENT OUTPUTS FOR TRAIN AND EVAL MODE"
            )
        # get feature extractor; unpacking accepts any sequence of names
        self.backbone = create_feature_extractor(
            self.backbone, return_nodes=[*self.layers, self.output_name]
        )
        # Dry run to get number of channels of each layer for the attention
        # mechanism
        if self.input_dim is not None:
            inp = torch.randn(self.input_dim)
        else:
            inp = torch.randn(1, 3, 224, 224)
        self.backbone.eval()
        outputs = self.backbone(inp)
        outputs.pop(self.output_name)
        self.feature_size = [o.shape for o in outputs.values()]
        self.backbone.register_forward_hook(self._simplify_graph_outputs())

    def _build_attention(self):
        """Build the attention mechanism.

        Replaces the attention class stored in ``self.attention`` with an
        instance. For transformer backbones the feature sizes are rewritten
        to ``(2, last_dim, 14, 14)`` and a pre-hook reshapes the incoming
        features accordingly.
        """
        # check if the model is a transformer
        if self._is_transformer():
            feature_size = [
                torch.Size([2, ft[-1], 14, 14]) for ft in self.feature_size
            ]
            self.attention = self.attention(feature_size, self.num_classes)
            self.attention.register_forward_pre_hook(
                self._feature_adapter_vit_b_16(), with_kwargs=True
            )
        else:
            self.attention = self.attention(
                self.feature_size, self.num_classes
            )

    def _is_transformer(self) -> bool:
        """Check if the model is a transformer.

        Returns:
            bool: True if the backbone contains an
            ``nn.MultiheadAttention`` module, False otherwise.
        """
        return any(
            isinstance(module, nn.MultiheadAttention)
            for module in self.backbone.modules()
        )

    def _simplify_graph_outputs(self):
        """Returns a hook to simplify the outputs of the GraphModule feature
        extractor."""

        def hook(
            module, inputs, outputs
        ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
            # the entry named ``output_name`` is the model output; every
            # remaining entry is a feature
            y: torch.Tensor = outputs.pop(self.output_name)
            features = list(outputs.values())
            return y, features

        return hook

    def _feature_adapter_vit_b_16(self):
        """Returns a hook to adapt the features of the Vision Transformer
        model.

        The hook drops the first token of every sequence, reshapes the rest
        to a 14x14 grid and moves the channel dimension next to the batch
        dimension. It adapts the features whether they arrive as the
        ``features`` keyword or as the first positional argument.
        """
        grid = 14

        def adapt(seq_list):
            adapted = []
            for seq in seq_list:
                # discard class token
                tokens = seq[:, 1:, :]
                # reshape to (batch, grid, grid, channels)
                tokens = tokens.reshape(seq.size(0), grid, grid, seq.size(2))
                # bring channels after batch dimension
                adapted.append(tokens.permute(0, 3, 1, 2))
            return adapted

        def hook(module, args, kwargs) -> Tuple[tuple, dict]:
            if "features" in kwargs:
                kwargs["features"] = adapt(kwargs["features"])
            elif args:
                # positional call, e.g. ``self.attention(features)``
                args = (adapt(args[0]), *args[1:])
            return args, kwargs

        return hook
Builder for the TTAME model.
TTAMEBuilder( backbone: torch.nn.modules.module.Module, attention: Type[torch.nn.modules.module.Module], masking_procedure, layers: List[str], input_dim: List[int] = None, num_classes: int = 1000, **kwargs)
def __init__(
    self,
    backbone: nn.Module,
    attention: Type[nn.Module],
    masking_procedure,
    layers: List[str],
    input_dim: List[int] = None,
    num_classes: int = 1000,
    **kwargs,
):
    """Record the build configuration without doing any work yet.

    Args:
        backbone: Model that will serve as the backbone.
        attention: Attention class (not an instance).
        masking_procedure: Masking procedure handed to the built model.
        layers: Names of the layers to take features from.
        input_dim: Optional input shape; ``None`` by default.
        num_classes: Number of classes, 1000 by default.
        **kwargs: Any further keyword arguments, kept as a dict.
    """
    # the three model components
    self.backbone, self.attention = backbone, attention
    self.masking_procedure = masking_procedure
    # feature-extraction settings
    self.layers, self.input_dim = layers, input_dim
    self.num_classes = num_classes
    # everything else is kept verbatim for later use
    self.kwargs = kwargs
@torch.no_grad()
def build(self) -> TTAME:
    """Assemble and return a ``TTAME`` model.

    Runs with gradient tracking disabled. The feature extractor is built
    before the attention mechanism, then both are combined with the
    masking procedure.
    """
    for step in (self._build_fx, self._build_attention):
        step()
    components = (self.backbone, self.attention, self.masking_procedure)
    return TTAME(*components)
Build the TTAME model.
@torch.no_grad()
def build_pipeline(self) -> PtamePipeline:
    """Build a ``PtamePipeline`` with TTAME-specific backbone steps.

    The ``"backbone"`` and ``"masked_backbone"`` hooks default to a step
    that unpacks the ``(logits, features)`` tuple returned by the backbone;
    hooks supplied through ``kwargs["hooks"]`` take precedence over these
    defaults.

    Returns:
        PtamePipeline: Built from the backbone, the attention mechanism,
        the masking procedure and the stored keyword arguments.
    """
    # build the feature extractor
    self._build_fx()
    # build the attention mechanism
    self._build_attention()

    def ttame_backbone_step(pipeline, _default_step):
        def backbone_step(x: dict) -> dict:
            """Forward pass of the backbone."""
            # a present ``masked_images`` entry selects the masked pass
            if (inp := x.get("masked_images")) is not None:
                x["logits_masked"], _ = pipeline.backbone(inp)
                x["targets_masked"] = x["logits_masked"].argmax(dim=1)
            else:
                x["logits"], x["features"] = pipeline.backbone(x["images"])
                x["targets"] = x["logits"].argmax(dim=1)
            return x

        return backbone_step

    # right-hand operand wins, so caller-provided hooks override defaults
    self.kwargs["hooks"] = {
        "backbone": ttame_backbone_step,
        "masked_backbone": ttame_backbone_step,
    } | self.kwargs.get("hooks", {})
    # build the model
    return PtamePipeline(
        self.backbone,
        self.attention,
        self.masking_procedure,
        **self.kwargs,
    )
Build the TTAME model as a PtamePipeline with TTAME-specific backbone hooks.
def build_ttame(**kwargs) -> TTAME:
    """Create a ``TTAMEBuilder`` from the keyword arguments and build with it.

    Args:
        Same as TTAMEBuilder.
    Returns:
        TTAME: The result of ``TTAMEBuilder.build``.
    """
    return TTAMEBuilder(**kwargs).build()
Build the TTAME model.
Args: Same as TTAMEBuilder. Returns: TTAME: The TTAME model.
def build_ttame_pipeline(**kwargs) -> PtamePipeline:
    """Build the TTAME model as a pipeline.

    Args:
        Same as TTAMEBuilder.
    Returns:
        PtamePipeline: The pipeline produced by
        ``TTAMEBuilder.build_pipeline``.
    """
    builder = TTAMEBuilder(**kwargs)
    return builder.build_pipeline()
Build the TTAME model as a pipeline.
Args: Same as TTAMEBuilder. Returns: PtamePipeline: The pipeline produced by TTAMEBuilder.build_pipeline.