ptame.models.components.ptame_pipeline

  1from collections.abc import Callable
  2
  3import torch
  4from torch import nn
  5
  6from ptame.utils.map_printer import XAIModel
  7from ptame.utils.masking_utils import (
  8    norm_resize_mask,
  9    random_map_select,
 10    single_map_select,
 11)
 12
 13
 14class PtamePipeline(nn.Module, XAIModel):
 15    """Perturbation-based Attention Mechanism for Explaining bLAck-box models
 16    (PAMELA)"""
 17
 18    def __init__(
 19        self,
 20        backbone: nn.Module,
 21        attention: nn.Module,
 22        masking_procedure=norm_resize_mask,
 23        train_map_select=random_map_select,
 24        matching_map_select=single_map_select,
 25        eval_map_select=None,
 26        backbone_eval=True,
 27        return_keys: list[str] = [
 28            "logits",
 29            "masks",
 30            "logits_masked",
 31            "targets",
 32        ],
 33        hooks: dict[str, Callable] = {},
 34        **kwargs,
 35    ):
 36        super().__init__()
 37        self.backbone = backbone
 38        self.backbone.requires_grad_(False)
 39        if backbone_eval:
 40            self.backbone.eval()
 41        self.attention = attention
 42        self.masking_procedure = masking_procedure
 43        self.train_map_select = train_map_select
 44        self.matching_map_select = matching_map_select
 45        self.eval_map_select = (
 46            a if (a := eval_map_select) is not None else matching_map_select
 47        )
 48        self.default_pipeline = {
 49            "backbone": self.backbone_step,
 50            "attention": self.attention_step,
 51            "masking": self.masking_step,
 52            "masked_backbone": self.backbone_step,
 53            "map_selection": self.selection_step,
 54        }
 55        phases = ["before_backbone"]
 56        phases += [f"after_{step}" for step in self.default_pipeline.keys()]
 57        self.hooks = {phase: [] for phase in phases}
 58        for phase, hook in hooks.items():
 59            if phase in self.default_pipeline.keys():
 60                # a total override of a step
 61                self.default_pipeline[phase] = hook(
 62                    self, self.default_pipeline[phase]
 63                )
 64            elif phase in phases:
 65                # a hook around a step
 66                self.hooks[phase].append(hook)
 67            else:
 68                raise ValueError(f"Unknown phase: {phase}")
 69        self.return_keys = return_keys
 70
 71    def _select_returned_maps(self, saliency_maps, targets):
 72        """Select the saliency maps that will be returned based on the targets
 73        and the stage."""
 74        if self.training:
 75            return self.train_map_select(saliency_maps, targets)
 76        else:
 77            return self._select_matching_maps(saliency_maps, targets)
 78
 79    def _select_matching_maps(self, saliency_maps, targets):
 80        """Get the saliency maps from the attention mechanism.
 81
 82        Depending on the stage and, masks as selected using a different
 83        procedure.
 84        """
 85        if self.training:
 86            return self.matching_map_select(saliency_maps, targets)
 87        else:
 88            return self.eval_map_select(saliency_maps, targets)
 89
 90    def backbone_step(self, x: dict) -> dict:
 91        """Forward pass of the backbone."""
 92        if (inp := x.get("masked_images", None)) is not None:
 93            x["logits_masked"] = self.backbone(inp)
 94            x["targets_masked"] = x["logits_masked"].argmax(dim=1)
 95        else:
 96            x["logits"] = self.backbone(x["images"])
 97            x["targets"] = x["logits"].argmax(dim=1)
 98        return x
 99
100    def attention_step(self, x: dict) -> dict:
101        """Forward pass of the attention mechanism."""
102        x["maps"] = self.attention(**x)
103        return x
104
105    def masking_step(self, x: dict) -> dict:
106        """Apply the masks to the input tensor and get the output."""
107        x["masked_images"] = self.masking_procedure(
108            x["images"], self._select_matching_maps(x["maps"], x["targets"])
109        )
110        return x
111
112    def selection_step(self, x: dict) -> dict:
113        """Select the saliency maps that will be returned based on the targets
114        and the stage."""
115        x["masks"] = self._select_returned_maps(x["maps"], x["targets"])
116        return x
117
118    def run_hooks(self, phase: str, x: dict) -> dict:
119        """Run the hooks for the given phase."""
120        for hook in self.hooks[phase]:
121            x = hook(self, x)
122        return x
123
124    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
125        """Forward pass of the entire model.
126
127        Depending on the stage, different outputs are returned.
128        """
129        data = {"images": x, "training": self.training}
130        data = self.run_hooks("before_backbone", data)
131        for step_name, step in self.default_pipeline.items():
132            data = step(data)
133            data = self.run_hooks("after_" + step_name, data)
134
135        return {
136            return_key: data[return_key] for return_key in self.return_keys
137        }
138
139    def produce_map(
140        self, image: torch.Tensor, **kwargs
141    ) -> tuple[torch.Tensor, int]:
142        """Produce a saliency map for the given image."""
143        data = {"images": image.unsqueeze(0), "training": False}
144        data = self.run_hooks("before_backbone", data)
145        for step_name, step in self.default_pipeline.items():
146            data = step(data)
147            data = self.run_hooks("after_" + step_name, data)
148        if classes := kwargs.get("classes", False):
149            top_classes = torch.topk(data["logits"], classes)[1]
150            maps = data["maps"].squeeze()[top_classes]
151            return maps, top_classes
152        else:
153            map = data["masks"].squeeze()
154            return map, data["targets"].squeeze()
155
156    def produce_cdmaps(
157        self, image: torch.Tensor
158    ) -> tuple[list[torch.Tensor], int]:
159        """Produce a list of saliency maps for the given image."""
160        image = image.unsqueeze(0)
161        out = self.backbone(image)
162        maps = self.attention(image)
163        predictions = out.squeeze()
164        return maps, predictions
165
166    def get_predictions(self, x):
167        """Get the predictions of the model.
168
169        Useful for the measures.
170        """
171        return self.default_pipeline["backbone"]({"images": x})["logits"]
class PtamePipeline(torch.nn.modules.module.Module, ptame.utils.map_printer.XAIModel):
 15class PtamePipeline(nn.Module, XAIModel):
 16    """Perturbation-based Attention Mechanism for Explaining bLAck-box models
 17    (PAMELA)"""
 18
 19    def __init__(
 20        self,
 21        backbone: nn.Module,
 22        attention: nn.Module,
 23        masking_procedure=norm_resize_mask,
 24        train_map_select=random_map_select,
 25        matching_map_select=single_map_select,
 26        eval_map_select=None,
 27        backbone_eval=True,
 28        return_keys: list[str] = [
 29            "logits",
 30            "masks",
 31            "logits_masked",
 32            "targets",
 33        ],
 34        hooks: dict[str, Callable] = {},
 35        **kwargs,
 36    ):
 37        super().__init__()
 38        self.backbone = backbone
 39        self.backbone.requires_grad_(False)
 40        if backbone_eval:
 41            self.backbone.eval()
 42        self.attention = attention
 43        self.masking_procedure = masking_procedure
 44        self.train_map_select = train_map_select
 45        self.matching_map_select = matching_map_select
 46        self.eval_map_select = (
 47            a if (a := eval_map_select) is not None else matching_map_select
 48        )
 49        self.default_pipeline = {
 50            "backbone": self.backbone_step,
 51            "attention": self.attention_step,
 52            "masking": self.masking_step,
 53            "masked_backbone": self.backbone_step,
 54            "map_selection": self.selection_step,
 55        }
 56        phases = ["before_backbone"]
 57        phases += [f"after_{step}" for step in self.default_pipeline.keys()]
 58        self.hooks = {phase: [] for phase in phases}
 59        for phase, hook in hooks.items():
 60            if phase in self.default_pipeline.keys():
 61                # a total override of a step
 62                self.default_pipeline[phase] = hook(
 63                    self, self.default_pipeline[phase]
 64                )
 65            elif phase in phases:
 66                # a hook around a step
 67                self.hooks[phase].append(hook)
 68            else:
 69                raise ValueError(f"Unknown phase: {phase}")
 70        self.return_keys = return_keys
 71
 72    def _select_returned_maps(self, saliency_maps, targets):
 73        """Select the saliency maps that will be returned based on the targets
 74        and the stage."""
 75        if self.training:
 76            return self.train_map_select(saliency_maps, targets)
 77        else:
 78            return self._select_matching_maps(saliency_maps, targets)
 79
 80    def _select_matching_maps(self, saliency_maps, targets):
 81        """Get the saliency maps from the attention mechanism.
 82
 83        Depending on the stage and, masks as selected using a different
 84        procedure.
 85        """
 86        if self.training:
 87            return self.matching_map_select(saliency_maps, targets)
 88        else:
 89            return self.eval_map_select(saliency_maps, targets)
 90
 91    def backbone_step(self, x: dict) -> dict:
 92        """Forward pass of the backbone."""
 93        if (inp := x.get("masked_images", None)) is not None:
 94            x["logits_masked"] = self.backbone(inp)
 95            x["targets_masked"] = x["logits_masked"].argmax(dim=1)
 96        else:
 97            x["logits"] = self.backbone(x["images"])
 98            x["targets"] = x["logits"].argmax(dim=1)
 99        return x
100
101    def attention_step(self, x: dict) -> dict:
102        """Forward pass of the attention mechanism."""
103        x["maps"] = self.attention(**x)
104        return x
105
106    def masking_step(self, x: dict) -> dict:
107        """Apply the masks to the input tensor and get the output."""
108        x["masked_images"] = self.masking_procedure(
109            x["images"], self._select_matching_maps(x["maps"], x["targets"])
110        )
111        return x
112
113    def selection_step(self, x: dict) -> dict:
114        """Select the saliency maps that will be returned based on the targets
115        and the stage."""
116        x["masks"] = self._select_returned_maps(x["maps"], x["targets"])
117        return x
118
119    def run_hooks(self, phase: str, x: dict) -> dict:
120        """Run the hooks for the given phase."""
121        for hook in self.hooks[phase]:
122            x = hook(self, x)
123        return x
124
125    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
126        """Forward pass of the entire model.
127
128        Depending on the stage, different outputs are returned.
129        """
130        data = {"images": x, "training": self.training}
131        data = self.run_hooks("before_backbone", data)
132        for step_name, step in self.default_pipeline.items():
133            data = step(data)
134            data = self.run_hooks("after_" + step_name, data)
135
136        return {
137            return_key: data[return_key] for return_key in self.return_keys
138        }
139
140    def produce_map(
141        self, image: torch.Tensor, **kwargs
142    ) -> tuple[torch.Tensor, int]:
143        """Produce a saliency map for the given image."""
144        data = {"images": image.unsqueeze(0), "training": False}
145        data = self.run_hooks("before_backbone", data)
146        for step_name, step in self.default_pipeline.items():
147            data = step(data)
148            data = self.run_hooks("after_" + step_name, data)
149        if classes := kwargs.get("classes", False):
150            top_classes = torch.topk(data["logits"], classes)[1]
151            maps = data["maps"].squeeze()[top_classes]
152            return maps, top_classes
153        else:
154            map = data["masks"].squeeze()
155            return map, data["targets"].squeeze()
156
157    def produce_cdmaps(
158        self, image: torch.Tensor
159    ) -> tuple[list[torch.Tensor], int]:
160        """Produce a list of saliency maps for the given image."""
161        image = image.unsqueeze(0)
162        out = self.backbone(image)
163        maps = self.attention(image)
164        predictions = out.squeeze()
165        return maps, predictions
166
167    def get_predictions(self, x):
168        """Get the predictions of the model.
169
170        Useful for the measures.
171        """
172        return self.default_pipeline["backbone"]({"images": x})["logits"]

Perturbation-based Attention Mechanism for Explaining bLAck-box models (PAMELA)

PtamePipeline( backbone: torch.nn.modules.module.Module, attention: torch.nn.modules.module.Module, masking_procedure=<function norm_resize_mask>, train_map_select=<function random_map_select>, matching_map_select=<function single_map_select>, eval_map_select=None, backbone_eval=True, return_keys: list[str] = ['logits', 'masks', 'logits_masked', 'targets'], hooks: dict[str, Callable] = {}, **kwargs)
def __init__(
    self,
    backbone: nn.Module,
    attention: nn.Module,
    masking_procedure=norm_resize_mask,
    train_map_select=random_map_select,
    matching_map_select=single_map_select,
    eval_map_select=None,
    backbone_eval=True,
    return_keys: list[str] = [
        "logits",
        "masks",
        "logits_masked",
        "targets",
    ],
    hooks: dict[str, Callable] = {},
    **kwargs,
):
    """Build the pipeline.

    Args:
        backbone: Classifier being explained; its parameters are frozen
            (``requires_grad_(False)``).
        attention: Module producing the saliency maps; it is called with the
            pipeline data dict unpacked as keyword arguments.
        masking_procedure: Callable ``f(images, maps)`` returning the masked
            images.
        train_map_select: Selects the returned maps in training mode.
        matching_map_select: Selects the maps used for masking in training
            mode.
        eval_map_select: Selects the maps in eval mode, both for masking and
            for the returned maps. Defaults to ``matching_map_select``.
        backbone_eval: Put the backbone in eval mode at construction.
        return_keys: Keys of the pipeline data dict returned by ``forward``.
        hooks: Entries keyed by a step name replace that step (called once
            as ``hook(self, original_step)``). Entries keyed by
            ``before_backbone`` or ``after_<step>`` are called as
            ``hook(self, data)`` around the steps.
        **kwargs: Unused by this method.

    Raises:
        ValueError: If a key of ``hooks`` is neither a step name nor a known
            phase.
    """
    super().__init__()
    self.backbone = backbone
    # The backbone is never optimized: no gradients for its parameters.
    self.backbone.requires_grad_(False)
    if backbone_eval:
        self.backbone.eval()
    self.attention = attention
    self.masking_procedure = masking_procedure
    self.train_map_select = train_map_select
    self.matching_map_select = matching_map_select
    self.eval_map_select = (
        eval_map_select if eval_map_select is not None else matching_map_select
    )
    self.default_pipeline = {
        "backbone": self.backbone_step,
        "attention": self.attention_step,
        "masking": self.masking_step,
        "masked_backbone": self.backbone_step,
        "map_selection": self.selection_step,
    }
    phases = ["before_backbone"]
    phases += [f"after_{step}" for step in self.default_pipeline]
    self.hooks = {phase: [] for phase in phases}
    for phase, hook in hooks.items():
        if phase in self.default_pipeline:
            # A total override of a step: the hook receives the step it
            # replaces, so it can wrap it.
            self.default_pipeline[phase] = hook(
                self, self.default_pipeline[phase]
            )
        elif phase in phases:
            # A hook run around a step.
            self.hooks[phase].append(hook)
        else:
            raise ValueError(f"Unknown phase: {phase}")
    # Copy, so instances never share (and mutate) the default list.
    self.return_keys = list(return_keys)

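Build the pipeline: freeze the backbone (optionally switching it to eval mode), store the attention module and the masking and map-selection callables, and register the step overrides and phase hooks given in `hooks`.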

backbone
attention
masking_procedure
train_map_select
matching_map_select
eval_map_select
default_pipeline
hooks
return_keys
def backbone_step(self, x: dict) -> dict:
def backbone_step(self, x: dict) -> dict:
    """Classify either the masked images or the plain images.

    When ``x`` carries a non-None ``masked_images`` entry, the backbone
    output is written to ``logits_masked`` and its dim-1 argmax to
    ``targets_masked``. Otherwise ``images`` are classified into ``logits``
    and ``targets``. The same dict is returned.
    """
    masked = x.get("masked_images")
    if masked is None:
        scores = self.backbone(x["images"])
        x["logits"] = scores
        x["targets"] = scores.argmax(dim=1)
    else:
        scores = self.backbone(masked)
        x["logits_masked"] = scores
        x["targets_masked"] = scores.argmax(dim=1)
    return x

Forward pass of the backbone.

def attention_step(self, x: dict) -> dict:
def attention_step(self, x: dict) -> dict:
    """Run the attention module on the data collected so far.

    Every entry of ``x`` is forwarded as a keyword argument; the output is
    stored under ``"maps"`` and the same dict is returned.
    """
    saliency = self.attention(**x)
    x["maps"] = saliency
    return x

Forward pass of the attention mechanism.

def masking_step(self, x: dict) -> dict:
def masking_step(self, x: dict) -> dict:
    """Mask ``x["images"]`` with the matching saliency maps.

    The maps are chosen by ``_select_matching_maps`` from ``x["maps"]`` and
    ``x["targets"]``. The output of ``masking_procedure`` is stored under
    ``"masked_images"`` and the same dict is returned.
    """
    chosen = self._select_matching_maps(x["maps"], x["targets"])
    x["masked_images"] = self.masking_procedure(x["images"], chosen)
    return x

Apply the masks to the input tensor and get the output.

def selection_step(self, x: dict) -> dict:
def selection_step(self, x: dict) -> dict:
    """Store under ``"masks"`` the saliency maps picked by
    ``_select_returned_maps`` from ``x["maps"]`` and ``x["targets"]``,
    then return the same dict."""
    returned = self._select_returned_maps(x["maps"], x["targets"])
    x["masks"] = returned
    return x

Select the saliency maps that will be returned based on the targets and the stage.

def run_hooks(self, phase: str, x: dict) -> dict:
def run_hooks(self, phase: str, x: dict) -> dict:
    """Thread ``x`` through every hook registered for ``phase``.

    Hooks are invoked in registration order as ``hook(self, value)``; each
    hook's return value feeds the next, and the last one is returned.
    """
    registered = self.hooks[phase]
    result = x
    for callback in registered:
        result = callback(self, result)
    return result

Run the hooks for the given phase.

def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Run the whole pipeline on a batch of images.

    The data dict is seeded with the images and the current training flag,
    passed through the ``before_backbone`` hooks, then through each step of
    ``default_pipeline`` followed by its ``after_<step>`` hooks. Only the
    entries named in ``return_keys`` are returned.
    """
    state = self.run_hooks(
        "before_backbone", {"images": x, "training": self.training}
    )
    for name, run_step in self.default_pipeline.items():
        state = self.run_hooks(f"after_{name}", run_step(state))
    return {key: state[key] for key in self.return_keys}

Forward pass of the entire model.

Depending on the stage, different outputs are returned.

def produce_map(self, image: torch.Tensor, **kwargs) -> tuple[torch.Tensor, int]:
def produce_map(
    self, image: torch.Tensor, **kwargs
) -> tuple[torch.Tensor, int]:
    """Produce a saliency map for a single (un-batched) image.

    The image gets a batch dimension of size 1 and goes through the full
    pipeline (hooks included) with ``training`` seeded as False. If a truthy
    ``classes`` keyword is given, the squeezed attention maps of the
    top-``classes`` logits are returned together with those class indices.
    Otherwise the squeezed selected map (``masks``) and the squeezed
    predicted target are returned.

    NOTE(review): map selection follows ``self.training``, not the seeded
    ``training`` entry. Call ``eval()`` first for eval-mode selection.
    """
    state = self.run_hooks(
        "before_backbone", {"images": image.unsqueeze(0), "training": False}
    )
    for name, run_step in self.default_pipeline.items():
        state = self.run_hooks(f"after_{name}", run_step(state))
    n_top = kwargs.get("classes", False)
    if not n_top:
        return state["masks"].squeeze(), state["targets"].squeeze()
    indices = torch.topk(state["logits"], n_top)[1]
    return state["maps"].squeeze()[indices], indices

Produce a saliency map for the given image.

def produce_cdmaps(self, image: torch.Tensor) -> tuple[list[torch.Tensor], int]:
def produce_cdmaps(
    self, image: torch.Tensor
) -> tuple[list[torch.Tensor], int]:
    """Return the raw attention output and the squeezed backbone logits for
    a single (un-batched) image.

    No hooks or step overrides are applied here.

    NOTE(review): ``attention`` receives the image positionally, unlike
    ``attention_step``, which passes the data dict as keyword arguments.
    """
    batch = image.unsqueeze(0)
    logits = self.backbone(batch)
    return self.attention(batch), logits.squeeze()

Produce a list of saliency maps for the given image.

def get_predictions(self, x):
def get_predictions(self, x):
    """Get the backbone logits for a batch of images.

    Useful for the measures. Only the ``backbone`` step of
    ``default_pipeline`` (possibly a user override) runs; hooks are not
    applied. The step receives the same ``images``/``training`` entries that
    ``forward`` seeds its data dict with, so an overridden step reading
    ``training`` works here too.
    """
    data = {"images": x, "training": self.training}
    return self.default_pipeline["backbone"](data)["logits"]

Get the predictions of the model.

Useful for the measures.