ptame.models.components.ptame_pipeline
from collections.abc import Callable

import torch
from torch import nn

from ptame.utils.map_printer import XAIModel
from ptame.utils.masking_utils import (
    norm_resize_mask,
    random_map_select,
    single_map_select,
)


class PtamePipeline(nn.Module, XAIModel):
    """Perturbation-based Attention Mechanism for Explaining bLAck-box models
    (PAMELA).

    The forward pass threads a single ``dict`` through a sequence of named
    steps (``self.default_pipeline``), in this order:

    * ``backbone`` -- classify ``images`` into ``logits`` / ``targets``;
    * ``attention`` -- compute saliency ``maps`` from the dict contents;
    * ``masking`` -- mask ``images`` with the selected maps
      (``masked_images``);
    * ``masked_backbone`` -- classify the masked images into
      ``logits_masked`` / ``targets_masked``;
    * ``map_selection`` -- choose the maps returned as ``masks``.

    Steps can be replaced and hooks inserted through the ``hooks`` argument
    of ``__init__``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        attention: nn.Module,
        masking_procedure=norm_resize_mask,
        train_map_select=random_map_select,
        matching_map_select=single_map_select,
        eval_map_select=None,
        backbone_eval=True,
        return_keys: list[str] = [
            "logits",
            "masks",
            "logits_masked",
            "targets",
        ],
        hooks: dict[str, Callable] = {},
        **kwargs,
    ):
        """Assemble the pipeline.

        Args:
            backbone: classifier being explained; its parameters are frozen
                with ``requires_grad_(False)``.
            attention: module producing saliency maps; ``attention_step``
                calls it with the data dict unpacked as keyword arguments.
            masking_procedure: ``f(images, maps)`` returning masked images.
            train_map_select: ``f(maps, targets)`` choosing the returned
                ``masks`` while training.
            matching_map_select: ``f(maps, targets)`` choosing the maps used
                for masking while training.
            eval_map_select: ``f(maps, targets)`` used for both masking and
                the returned ``masks`` outside training; defaults to
                ``matching_map_select``.
            backbone_eval: if true, put the backbone in eval mode here.
            return_keys: data-dict entries returned by ``forward``. Copied,
                so mutating ``self.return_keys`` never alters the default.
            hooks: maps a step name (``"backbone"``, ``"attention"``,
                ``"masking"``, ``"masked_backbone"``, ``"map_selection"``) to
                a factory ``hook(self, current_step)`` whose result replaces
                that step, or a phase name (``"before_backbone"`` or
                ``"after_<step>"``) to a callable ``hook(self, data)`` that
                returns the updated data dict.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if a key of ``hooks`` is neither a step nor a phase.
        """
        super().__init__()
        self.backbone = backbone
        self.backbone.requires_grad_(False)
        if backbone_eval:
            # NOTE(review): nn.Module.train() recurses into submodules, so a
            # later ``model.train()`` puts the backbone back in training
            # mode; confirm whether a ``train()`` override is intended.
            self.backbone.eval()
        self.attention = attention
        self.masking_procedure = masking_procedure
        self.train_map_select = train_map_select
        self.matching_map_select = matching_map_select
        self.eval_map_select = (
            eval_map_select
            if eval_map_select is not None
            else matching_map_select
        )
        # Order matters: later steps read keys written by earlier ones, and
        # "masked_backbone" reuses backbone_step, which switches on the
        # presence of "masked_images".
        self.default_pipeline = {
            "backbone": self.backbone_step,
            "attention": self.attention_step,
            "masking": self.masking_step,
            "masked_backbone": self.backbone_step,
            "map_selection": self.selection_step,
        }
        phases = ["before_backbone"]
        phases += [f"after_{step}" for step in self.default_pipeline]
        self.hooks = {phase: [] for phase in phases}
        for phase, hook in hooks.items():
            if phase in self.default_pipeline:
                # Total override: the hook wraps/replaces the current step.
                self.default_pipeline[phase] = hook(
                    self, self.default_pipeline[phase]
                )
            elif phase in self.hooks:
                # A hook run around a step.
                self.hooks[phase].append(hook)
            else:
                raise ValueError(f"Unknown phase: {phase}")
        # Copy so instances never share (and mutate) the default list.
        self.return_keys = list(return_keys)

    def _select_returned_maps(self, saliency_maps, targets):
        """Select the saliency maps returned as ``masks``.

        While training this uses ``train_map_select``; otherwise it uses the
        same selection as for masking (``eval_map_select``).
        """
        if self.training:
            return self.train_map_select(saliency_maps, targets)
        return self._select_matching_maps(saliency_maps, targets)

    def _select_matching_maps(self, saliency_maps, targets):
        """Select the saliency maps used to mask the input images.

        Depending on the stage, masks are selected with a different
        procedure: ``matching_map_select`` while training, otherwise
        ``eval_map_select``.
        """
        if self.training:
            return self.matching_map_select(saliency_maps, targets)
        return self.eval_map_select(saliency_maps, targets)

    def backbone_step(self, x: dict) -> dict:
        """Forward pass of the backbone.

        Classifies ``x["masked_images"]`` into ``logits_masked`` /
        ``targets_masked`` when that entry is present and not None, otherwise
        ``x["images"]`` into ``logits`` / ``targets``. Targets are the argmax
        over dimension 1. ``x`` is updated in place and returned.
        """
        if (inp := x.get("masked_images", None)) is not None:
            x["logits_masked"] = self.backbone(inp)
            x["targets_masked"] = x["logits_masked"].argmax(dim=1)
        else:
            x["logits"] = self.backbone(x["images"])
            x["targets"] = x["logits"].argmax(dim=1)
        return x

    def attention_step(self, x: dict) -> dict:
        """Forward pass of the attention mechanism; stores ``x["maps"]``."""
        x["maps"] = self.attention(**x)
        return x

    def masking_step(self, x: dict) -> dict:
        """Mask ``x["images"]`` with the matching maps (``masked_images``)."""
        x["masked_images"] = self.masking_procedure(
            x["images"], self._select_matching_maps(x["maps"], x["targets"])
        )
        return x

    def selection_step(self, x: dict) -> dict:
        """Select the saliency maps that will be returned based on the
        targets and the stage; stores them in ``x["masks"]``."""
        x["masks"] = self._select_returned_maps(x["maps"], x["targets"])
        return x

    def run_hooks(self, phase: str, x: dict) -> dict:
        """Run the hooks for ``phase`` in order, each as ``hook(self, x)``."""
        for hook in self.hooks[phase]:
            x = hook(self, x)
        return x

    def _run_pipeline(self, data: dict) -> dict:
        """Run ``before_backbone`` hooks, then every step followed by its
        ``after_<step>`` hooks; return the final data dict."""
        data = self.run_hooks("before_backbone", data)
        for step_name, step in self.default_pipeline.items():
            data = step(data)
            data = self.run_hooks("after_" + step_name, data)
        return data

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Forward pass of the entire model.

        Depending on the stage, different outputs are returned: only the
        data-dict entries named in ``self.return_keys``.
        """
        data = self._run_pipeline({"images": x, "training": self.training})
        return {key: data[key] for key in self.return_keys}

    def produce_map(
        self, image: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, int]:
        """Produce a saliency map for a single (unbatched) image.

        The image gets a leading batch dimension and runs through the same
        hooks and steps as ``forward``. The data dict carries
        ``"training": False``, but the default selection steps read
        ``self.training``.

        Keyword Args:
            classes: if truthy, the number ``k`` of top-scoring classes (by
                ``logits``) whose maps are returned.

        Returns:
            ``(maps, top_classes)`` when ``classes`` is given, otherwise
            ``(mask, target)`` from the squeezed ``masks`` / ``targets``.
        """
        data = self._run_pipeline(
            {"images": image.unsqueeze(0), "training": False}
        )
        if classes := kwargs.get("classes", False):
            # NOTE(review): topk keeps the batch dimension of ``logits``, so
            # ``top_classes`` and the indexed maps carry a leading dim of 1;
            # confirm callers expect that.
            top_classes = torch.topk(data["logits"], classes)[1]
            maps = data["maps"].squeeze()[top_classes]
            return maps, top_classes
        saliency_map = data["masks"].squeeze()
        return saliency_map, data["targets"].squeeze()

    def produce_cdmaps(
        self, image: torch.Tensor
    ) -> tuple[list[torch.Tensor], int]:
        """Produce the raw attention output and backbone predictions for a
        single (unbatched) image.

        NOTE(review): the annotation says ``list`` / ``int``, but
        ``predictions`` is the squeezed backbone output and ``maps`` is
        whatever ``self.attention`` returns. Also, ``attention`` is called
        positionally here but with keyword arguments in ``attention_step``;
        confirm both call forms are supported.
        """
        image = image.unsqueeze(0)
        out = self.backbone(image)
        maps = self.attention(image)
        predictions = out.squeeze()
        return maps, predictions

    def get_predictions(self, x):
        """Get the predictions of the model.

        Runs only the (possibly overridden) ``backbone`` step, without hooks,
        and returns its ``logits``. Useful for the measures.
        """
        return self.default_pipeline["backbone"]({"images": x})["logits"]
class
PtamePipeline(torch.nn.modules.module.Module, ptame.utils.map_printer.XAIModel):
class PtamePipeline(nn.Module, XAIModel):
    """Perturbation-based Attention Mechanism for Explaining bLAck-box models
    (PAMELA).

    The forward pass threads a single ``dict`` through a sequence of named
    steps (``self.default_pipeline``), in this order:

    * ``backbone`` -- classify ``images`` into ``logits`` / ``targets``;
    * ``attention`` -- compute saliency ``maps`` from the dict contents;
    * ``masking`` -- mask ``images`` with the selected maps
      (``masked_images``);
    * ``masked_backbone`` -- classify the masked images into
      ``logits_masked`` / ``targets_masked``;
    * ``map_selection`` -- choose the maps returned as ``masks``.

    Steps can be replaced and hooks inserted through the ``hooks`` argument
    of ``__init__``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        attention: nn.Module,
        masking_procedure=norm_resize_mask,
        train_map_select=random_map_select,
        matching_map_select=single_map_select,
        eval_map_select=None,
        backbone_eval=True,
        return_keys: list[str] = [
            "logits",
            "masks",
            "logits_masked",
            "targets",
        ],
        hooks: dict[str, Callable] = {},
        **kwargs,
    ):
        """Assemble the pipeline.

        Args:
            backbone: classifier being explained; its parameters are frozen
                with ``requires_grad_(False)``.
            attention: module producing saliency maps; ``attention_step``
                calls it with the data dict unpacked as keyword arguments.
            masking_procedure: ``f(images, maps)`` returning masked images.
            train_map_select: ``f(maps, targets)`` choosing the returned
                ``masks`` while training.
            matching_map_select: ``f(maps, targets)`` choosing the maps used
                for masking while training.
            eval_map_select: ``f(maps, targets)`` used for both masking and
                the returned ``masks`` outside training; defaults to
                ``matching_map_select``.
            backbone_eval: if true, put the backbone in eval mode here.
            return_keys: data-dict entries returned by ``forward``. Copied,
                so mutating ``self.return_keys`` never alters the default.
            hooks: maps a step name (``"backbone"``, ``"attention"``,
                ``"masking"``, ``"masked_backbone"``, ``"map_selection"``) to
                a factory ``hook(self, current_step)`` whose result replaces
                that step, or a phase name (``"before_backbone"`` or
                ``"after_<step>"``) to a callable ``hook(self, data)`` that
                returns the updated data dict.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if a key of ``hooks`` is neither a step nor a phase.
        """
        super().__init__()
        self.backbone = backbone
        self.backbone.requires_grad_(False)
        if backbone_eval:
            # NOTE(review): nn.Module.train() recurses into submodules, so a
            # later ``model.train()`` puts the backbone back in training
            # mode; confirm whether a ``train()`` override is intended.
            self.backbone.eval()
        self.attention = attention
        self.masking_procedure = masking_procedure
        self.train_map_select = train_map_select
        self.matching_map_select = matching_map_select
        self.eval_map_select = (
            eval_map_select
            if eval_map_select is not None
            else matching_map_select
        )
        # Order matters: later steps read keys written by earlier ones, and
        # "masked_backbone" reuses backbone_step, which switches on the
        # presence of "masked_images".
        self.default_pipeline = {
            "backbone": self.backbone_step,
            "attention": self.attention_step,
            "masking": self.masking_step,
            "masked_backbone": self.backbone_step,
            "map_selection": self.selection_step,
        }
        phases = ["before_backbone"]
        phases += [f"after_{step}" for step in self.default_pipeline]
        self.hooks = {phase: [] for phase in phases}
        for phase, hook in hooks.items():
            if phase in self.default_pipeline:
                # Total override: the hook wraps/replaces the current step.
                self.default_pipeline[phase] = hook(
                    self, self.default_pipeline[phase]
                )
            elif phase in self.hooks:
                # A hook run around a step.
                self.hooks[phase].append(hook)
            else:
                raise ValueError(f"Unknown phase: {phase}")
        # Copy so instances never share (and mutate) the default list.
        self.return_keys = list(return_keys)

    def _select_returned_maps(self, saliency_maps, targets):
        """Select the saliency maps returned as ``masks``.

        While training this uses ``train_map_select``; otherwise it uses the
        same selection as for masking (``eval_map_select``).
        """
        if self.training:
            return self.train_map_select(saliency_maps, targets)
        return self._select_matching_maps(saliency_maps, targets)

    def _select_matching_maps(self, saliency_maps, targets):
        """Select the saliency maps used to mask the input images.

        Depending on the stage, masks are selected with a different
        procedure: ``matching_map_select`` while training, otherwise
        ``eval_map_select``.
        """
        if self.training:
            return self.matching_map_select(saliency_maps, targets)
        return self.eval_map_select(saliency_maps, targets)

    def backbone_step(self, x: dict) -> dict:
        """Forward pass of the backbone.

        Classifies ``x["masked_images"]`` into ``logits_masked`` /
        ``targets_masked`` when that entry is present and not None, otherwise
        ``x["images"]`` into ``logits`` / ``targets``. Targets are the argmax
        over dimension 1. ``x`` is updated in place and returned.
        """
        if (inp := x.get("masked_images", None)) is not None:
            x["logits_masked"] = self.backbone(inp)
            x["targets_masked"] = x["logits_masked"].argmax(dim=1)
        else:
            x["logits"] = self.backbone(x["images"])
            x["targets"] = x["logits"].argmax(dim=1)
        return x

    def attention_step(self, x: dict) -> dict:
        """Forward pass of the attention mechanism; stores ``x["maps"]``."""
        x["maps"] = self.attention(**x)
        return x

    def masking_step(self, x: dict) -> dict:
        """Mask ``x["images"]`` with the matching maps (``masked_images``)."""
        x["masked_images"] = self.masking_procedure(
            x["images"], self._select_matching_maps(x["maps"], x["targets"])
        )
        return x

    def selection_step(self, x: dict) -> dict:
        """Select the saliency maps that will be returned based on the
        targets and the stage; stores them in ``x["masks"]``."""
        x["masks"] = self._select_returned_maps(x["maps"], x["targets"])
        return x

    def run_hooks(self, phase: str, x: dict) -> dict:
        """Run the hooks for ``phase`` in order, each as ``hook(self, x)``."""
        for hook in self.hooks[phase]:
            x = hook(self, x)
        return x

    def _run_pipeline(self, data: dict) -> dict:
        """Run ``before_backbone`` hooks, then every step followed by its
        ``after_<step>`` hooks; return the final data dict."""
        data = self.run_hooks("before_backbone", data)
        for step_name, step in self.default_pipeline.items():
            data = step(data)
            data = self.run_hooks("after_" + step_name, data)
        return data

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Forward pass of the entire model.

        Depending on the stage, different outputs are returned: only the
        data-dict entries named in ``self.return_keys``.
        """
        data = self._run_pipeline({"images": x, "training": self.training})
        return {key: data[key] for key in self.return_keys}

    def produce_map(
        self, image: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, int]:
        """Produce a saliency map for a single (unbatched) image.

        The image gets a leading batch dimension and runs through the same
        hooks and steps as ``forward``. The data dict carries
        ``"training": False``, but the default selection steps read
        ``self.training``.

        Keyword Args:
            classes: if truthy, the number ``k`` of top-scoring classes (by
                ``logits``) whose maps are returned.

        Returns:
            ``(maps, top_classes)`` when ``classes`` is given, otherwise
            ``(mask, target)`` from the squeezed ``masks`` / ``targets``.
        """
        data = self._run_pipeline(
            {"images": image.unsqueeze(0), "training": False}
        )
        if classes := kwargs.get("classes", False):
            # NOTE(review): topk keeps the batch dimension of ``logits``, so
            # ``top_classes`` and the indexed maps carry a leading dim of 1;
            # confirm callers expect that.
            top_classes = torch.topk(data["logits"], classes)[1]
            maps = data["maps"].squeeze()[top_classes]
            return maps, top_classes
        saliency_map = data["masks"].squeeze()
        return saliency_map, data["targets"].squeeze()

    def produce_cdmaps(
        self, image: torch.Tensor
    ) -> tuple[list[torch.Tensor], int]:
        """Produce the raw attention output and backbone predictions for a
        single (unbatched) image.

        NOTE(review): the annotation says ``list`` / ``int``, but
        ``predictions`` is the squeezed backbone output and ``maps`` is
        whatever ``self.attention`` returns. Also, ``attention`` is called
        positionally here but with keyword arguments in ``attention_step``;
        confirm both call forms are supported.
        """
        image = image.unsqueeze(0)
        out = self.backbone(image)
        maps = self.attention(image)
        predictions = out.squeeze()
        return maps, predictions

    def get_predictions(self, x):
        """Get the predictions of the model.

        Runs only the (possibly overridden) ``backbone`` step, without hooks,
        and returns its ``logits``. Useful for the measures.
        """
        return self.default_pipeline["backbone"]({"images": x})["logits"]
Perturbation-based Attention Mechanism for Explaining bLAck-box models (PAMELA)
PtamePipeline( backbone: torch.nn.modules.module.Module, attention: torch.nn.modules.module.Module, masking_procedure=<function norm_resize_mask>, train_map_select=<function random_map_select>, matching_map_select=<function single_map_select>, eval_map_select=None, backbone_eval=True, return_keys: list[str] = ['logits', 'masks', 'logits_masked', 'targets'], hooks: dict[str, Callable] = {}, **kwargs)
def __init__(
    self,
    backbone: nn.Module,
    attention: nn.Module,
    masking_procedure=norm_resize_mask,
    train_map_select=random_map_select,
    matching_map_select=single_map_select,
    eval_map_select=None,
    backbone_eval=True,
    return_keys: list[str] = [
        "logits",
        "masks",
        "logits_masked",
        "targets",
    ],
    hooks: dict[str, Callable] = {},
    **kwargs,
):
    """Assemble the pipeline.

    Args:
        backbone: classifier being explained; its parameters are frozen with
            ``requires_grad_(False)``.
        attention: module producing saliency maps.
        masking_procedure: ``f(images, maps)`` returning masked images.
        train_map_select: ``f(maps, targets)`` choosing the returned maps
            while training.
        matching_map_select: ``f(maps, targets)`` choosing the maps used for
            masking while training.
        eval_map_select: selection used outside training; defaults to
            ``matching_map_select``.
        backbone_eval: if true, put the backbone in eval mode here.
        return_keys: data-dict entries returned by ``forward``. Copied, so
            mutating ``self.return_keys`` never alters the shared default.
        hooks: maps a step name (a key of ``default_pipeline``) to a factory
            ``hook(self, current_step)`` whose result replaces that step, or a
            phase name (``"before_backbone"`` / ``"after_<step>"``) to a
            callable appended to that phase's hook list.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if a key of ``hooks`` is neither a step nor a phase.
    """
    super().__init__()
    self.backbone = backbone
    self.backbone.requires_grad_(False)
    if backbone_eval:
        # NOTE(review): nn.Module.train() recurses into submodules, so a later
        # ``model.train()`` puts the backbone back in training mode; confirm
        # whether a ``train()`` override is intended.
        self.backbone.eval()
    self.attention = attention
    self.masking_procedure = masking_procedure
    self.train_map_select = train_map_select
    self.matching_map_select = matching_map_select
    self.eval_map_select = (
        eval_map_select if eval_map_select is not None else matching_map_select
    )
    # Order matters: later steps read keys written by earlier ones.
    self.default_pipeline = {
        "backbone": self.backbone_step,
        "attention": self.attention_step,
        "masking": self.masking_step,
        "masked_backbone": self.backbone_step,
        "map_selection": self.selection_step,
    }
    phases = ["before_backbone"]
    phases += [f"after_{step}" for step in self.default_pipeline]
    self.hooks = {phase: [] for phase in phases}
    for phase, hook in hooks.items():
        if phase in self.default_pipeline:
            # Total override: the hook wraps/replaces the current step.
            self.default_pipeline[phase] = hook(
                self, self.default_pipeline[phase]
            )
        elif phase in self.hooks:
            # A hook run around a step.
            self.hooks[phase].append(hook)
        else:
            raise ValueError(f"Unknown phase: {phase}")
    # Copy so instances never share (and mutate) the default list.
    self.return_keys = list(return_keys)
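Build the pipeline. The constructor freezes the backbone's parameters and, when `backbone_eval` is true, puts the backbone in eval mode. It stores the attention module, the masking procedure and the map-selection callables; `eval_map_select` falls back to `matching_map_select`. It then registers `hooks`: a key naming a pipeline step replaces that step, a `before_backbone`/`after_<step>` key adds a hook at that phase, and any other key raises `ValueError`. (`super().__init__()` initializes the internal `nn.Module` state.)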
def
backbone_step(self, x: dict) -> dict:
def backbone_step(self, x: dict) -> dict:
    """Run the backbone on whichever images the data dict currently holds.

    When ``x`` has a non-None ``"masked_images"`` entry, those images are
    classified into ``"logits_masked"`` / ``"targets_masked"``; otherwise
    ``x["images"]`` is classified into ``"logits"`` / ``"targets"``. Targets
    are the argmax over dimension 1. ``x`` is updated in place and returned.
    """
    masked = x.get("masked_images")
    if masked is None:
        scores = self.backbone(x["images"])
        x["logits"] = scores
        x["targets"] = scores.argmax(dim=1)
    else:
        scores = self.backbone(masked)
        x["logits_masked"] = scores
        x["targets_masked"] = scores.argmax(dim=1)
    return x
Forward pass of the backbone.
def
attention_step(self, x: dict) -> dict:
def attention_step(self, x: dict) -> dict:
    """Compute saliency maps with the attention module.

    Every entry currently in ``x`` is forwarded to ``self.attention`` as a
    keyword argument. The result is stored under ``x["maps"]``, and ``x`` is
    returned.
    """
    saliency_maps = self.attention(**x)
    x["maps"] = saliency_maps
    return x
Forward pass of the attention mechanism.
def
masking_step(self, x: dict) -> dict:
def masking_step(self, x: dict) -> dict:
    """Mask the input images with the maps chosen for matching.

    ``self._select_matching_maps`` picks maps from ``x["maps"]`` given
    ``x["targets"]``. ``self.masking_procedure`` then applies them to
    ``x["images"]``, and the result is stored in ``x["masked_images"]``.
    """
    chosen_maps = self._select_matching_maps(x["maps"], x["targets"])
    x["masked_images"] = self.masking_procedure(x["images"], chosen_maps)
    return x
Apply the masks to the input tensor and get the output.
def
selection_step(self, x: dict) -> dict:
def selection_step(self, x: dict) -> dict:
    """Pick the saliency maps the pipeline returns as ``"masks"``.

    Delegates to ``self._select_returned_maps`` with ``x["maps"]`` and
    ``x["targets"]``; that helper's rule depends on the training stage.
    """
    maps, targets = x["maps"], x["targets"]
    x["masks"] = self._select_returned_maps(maps, targets)
    return x
Select the saliency maps that will be returned based on the targets and the stage.
def
run_hooks(self, phase: str, x: dict) -> dict:
def run_hooks(self, phase: str, x: dict) -> dict:
    """Thread ``x`` through every hook registered for ``phase``, in order.

    Each hook is called as ``hook(self, value)``, and its return value is
    fed to the next hook. The last result (or ``x`` itself when no hooks are
    registered) is returned.
    """
    result = x
    for hook_fn in self.hooks[phase]:
        result = hook_fn(self, result)
    return result
Run the hooks for the given phase.
def
forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Run the full explanation pipeline on a batch of images.

    A data dict holding ``images`` and the module's current ``training`` flag
    first goes through the ``before_backbone`` hooks. It then goes through
    each entry of ``self.default_pipeline`` in order, each step followed by
    its ``after_<step>`` hooks. Only the entries listed in
    ``self.return_keys`` are returned, so the output depends on the stage and
    on the steps.
    """
    state = {"images": x, "training": self.training}
    state = self.run_hooks("before_backbone", state)
    for name, run_step in self.default_pipeline.items():
        state = self.run_hooks(f"after_{name}", run_step(state))
    return {wanted: state[wanted] for wanted in self.return_keys}
Forward pass of the entire model.
Depending on the stage, different outputs are returned.
def
produce_map(self, image: torch.Tensor, **kwargs) -> tuple[torch.Tensor, int]:
def produce_map(
    self, image: torch.Tensor, **kwargs
) -> tuple[torch.Tensor, int]:
    """Produce a saliency map for a single (unbatched) image.

    The image gets a leading batch dimension and is pushed through the
    ``before_backbone`` hooks and every pipeline step (each followed by its
    ``after_<step>`` hooks), exactly as in ``forward``. The data dict carries
    ``"training": False``, but the default selection steps read
    ``self.training``.

    Keyword Args:
        classes: if truthy, the number ``k`` of top-scoring classes (by
            ``logits``) whose maps are returned.

    Returns:
        ``(maps, top_classes)`` when ``classes`` is given; otherwise
        ``(mask, target)`` built from the squeezed ``masks`` and ``targets``
        entries.
    """
    data = {"images": image.unsqueeze(0), "training": False}
    data = self.run_hooks("before_backbone", data)
    for step_name, step in self.default_pipeline.items():
        data = step(data)
        data = self.run_hooks("after_" + step_name, data)
    if classes := kwargs.get("classes", False):
        # NOTE(review): topk keeps the batch dimension of ``logits``, so
        # ``top_classes`` and the indexed maps carry a leading dim of 1;
        # confirm callers expect that.
        top_classes = torch.topk(data["logits"], classes)[1]
        class_maps = data["maps"].squeeze()[top_classes]
        return class_maps, top_classes
    # Renamed from ``map`` so the builtin is not shadowed.
    saliency_map = data["masks"].squeeze()
    return saliency_map, data["targets"].squeeze()
Produce a saliency map for the given image.
def
produce_cdmaps(self, image: torch.Tensor) -> tuple[list[torch.Tensor], int]:
def produce_cdmaps(
    self, image: torch.Tensor
) -> tuple[list[torch.Tensor], int]:
    """Return the raw attention output and squeezed backbone output for one
    (unbatched) image.

    The image gets a leading batch dimension and is fed to the backbone and
    then to the attention module (called positionally).
    """
    batched = image.unsqueeze(0)
    raw_scores = self.backbone(batched)
    attention_maps = self.attention(batched)
    return attention_maps, raw_scores.squeeze()
Produce a list of saliency maps for the given image.