ptame.models.components.ttame

  1from typing import List, Tuple, Type
  2
  3import torch
  4from torch import nn
  5from torchvision.models.feature_extraction import (
  6    create_feature_extractor,
  7    get_graph_node_names,
  8)
  9
 10from ptame.models.components.ptame_pipeline import PtamePipeline
 11from ptame.utils.map_printer import XAIModel
 12
 13
 14class TTAME(PtamePipeline, XAIModel):
 15    """Transformer-compatible Trainable Attention Mechanism for
 16    Explainability."""
 17
 18    def __init__(
 19        self,
 20        backbone: nn.Module,
 21        attention: nn.Module,
 22        masking_procedure,
 23    ):
 24        super().__init__(backbone, attention, masking_procedure)
 25
 26    def _forward_backbone(
 27        self, x: torch.Tensor
 28    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
 29        """Forward pass of the backbone (feature extractor)."""
 30        y, features = self.backbone(x)
 31        return y, features
 32
 33    def _forward_second(
 34        self, x: torch.Tensor, features: List[torch.Tensor], targets
 35    ):
 36        """Second forward pass, for the training stage only."""
 37        # Get the saliency maps from the attention mechanism
 38        saliency_maps = self.attention(features)
 39        # Apply the masks to the input tensor and get the output
 40        x_masked = self.masking_procedure(
 41            x, self._select_matching_maps(saliency_maps, targets)
 42        )
 43        y_masked, _ = self.backbone(x_masked)
 44        # Select the saliency maps which will be returned based on the targets and the stage
 45        saliency_maps = self._select_returned_maps(saliency_maps, targets)
 46        return y_masked, saliency_maps
 47
 48    def forward(self, x: torch.Tensor):
 49        """Forward pass of the entire model.
 50
 51        Depending on the stage, different outputs are returned.
 52        """
 53        y, features = self._forward_backbone(x.clone().detach())
 54        targets = y.argmax(dim=1)
 55        y_masked, maps = self._forward_second(x, features, targets)
 56        return {
 57            "logits": y,
 58            "logits_masked": y_masked,
 59            "targets": targets,
 60            "masks": maps,
 61        }
 62
 63    def get_predictions(self, x):
 64        """Get the predictions of the model.
 65
 66        Useful for the measures.
 67        """
 68        return self._forward_backbone(x)[0]
 69
 70    def produce_map(self, image: torch.Tensor) -> Tuple[torch.Tensor, int]:
 71        image = image.unsqueeze(0)
 72        out, features = self._forward_backbone(image)
 73        maps = self.attention(features)
 74        prediction = out[0].argmax().item()
 75        map = maps[0, prediction]
 76        return map, prediction
 77
 78    def produce_cdmaps(
 79        self, image: torch.Tensor
 80    ) -> Tuple[List[torch.Tensor], int]:
 81        image = image.unsqueeze(0)
 82        out, features = self._forward_backbone(image)
 83        maps = self.attention(features)
 84        predictions = out[0]
 85        return maps, predictions
 86
 87
 88class TTAMEBuilder:
 89    """Builder for the TTAME model."""
 90
 91    def __init__(
 92        self,
 93        backbone: nn.Module,
 94        attention: Type[nn.Module],
 95        masking_procedure,
 96        layers: List[str],
 97        input_dim: List[int] = None,
 98        num_classes: int = 1000,
 99        **kwargs,
100    ):
101        self.backbone = backbone
102        self.attention = attention
103        self.masking_procedure = masking_procedure
104        self.layers = layers
105        self.input_dim = input_dim
106        self.num_classes = num_classes
107        self.kwargs = kwargs
108
109    @torch.no_grad()
110    def build(self) -> TTAME:
111        """Build the TTAME model."""
112        # build the feature extractor
113        self._build_fx()
114        # build the attention mechanism
115        self._build_attention()
116        # build the model
117        return TTAME(
118            self.backbone,
119            self.attention,
120            self.masking_procedure,
121        )
122
123    @torch.no_grad()
124    def build_pipeline(self) -> TTAME:
125        """Build the TTAME model."""
126        # build the feature extractor
127        self._build_fx()
128        # build the attention mechanism
129        self._build_attention()
130
131        def ttame_backbone_step(pipeline, _default_step):
132            def backbone_step(x: dict) -> dict:
133                """Forward pass of the backbone."""
134                if (inp := x.get("masked_images", None)) is not None:
135                    x["logits_masked"], _ = pipeline.backbone(inp)
136                    x["targets_masked"] = x["logits_masked"].argmax(dim=1)
137                else:
138                    x["logits"], x["features"] = pipeline.backbone(x["images"])
139                    x["targets"] = x["logits"].argmax(dim=1)
140                return x
141
142            return backbone_step
143
144        self.kwargs["hooks"] = {
145            "backbone": ttame_backbone_step,
146            "masked_backbone": ttame_backbone_step,
147        } | self.kwargs.get("hooks", {})
148        # build the model
149        return PtamePipeline(
150            self.backbone,
151            self.attention,
152            self.masking_procedure,
153            **self.kwargs,
154        )
155
156    def _build_fx(self):
157        """Build feature extractor."""
158        # if no layers are specified, print layers and quit
159        train_names, eval_names = get_graph_node_names(self.backbone)
160        if self.layers == [] or self.layers is None:
161            print(train_names)
162            quit()
163        # get the output layer name
164        output = (train_names[-1], eval_names[-1])
165        if output[0] != output[1]:
166            print(
167                "WARNING! THIS MODEL HAS DIFFERENT OUTPUTS FOR TRAIN AND EVAL MODE"
168            )
169        self.output_name = output[0]
170        # get feature extractor
171        self.backbone = create_feature_extractor(
172            self.backbone, return_nodes=(self.layers + [self.output_name])
173        )
174        # Dry run to get number of channels of each layer for the attention mechanism
175        if self.input_dim is not None:
176            inp = torch.randn(self.input_dim)
177        else:
178            inp = torch.randn(1, 3, 224, 224)
179        self.backbone.eval()
180        outputs = self.backbone(inp)
181        outputs.pop(self.output_name)
182        features = outputs.values()
183        self.feature_size = [o.shape for o in features]
184        self.backbone.register_forward_hook(self._simplify_graph_outputs())
185
186    def _build_attention(self):
187        """Build the attention mechanism."""
188        # check if the model is a transformer
189        if self._is_transformer():
190            feature_size = [
191                torch.Size([2, ft[-1], 14, 14]) for ft in self.feature_size
192            ]
193            self.attention = self.attention(feature_size, self.num_classes)
194            self.attention.register_forward_pre_hook(
195                self._feature_adapter_vit_b_16(), with_kwargs=True
196            )
197        else:
198            self.attention = self.attention(
199                self.feature_size, self.num_classes
200            )
201
202    def _is_transformer(self) -> bool:
203        """Check if the model is a transformer.
204
205        Returns:
206            bool: True if the model is a transformer, False otherwise.
207        """
208        for module in self.backbone.modules():
209            if isinstance(module, nn.MultiheadAttention):
210                return True
211        return False
212
213    def _simplify_graph_outputs(self):
214        """Returns a hook to simplify the outputs of the GraphModule feature
215        extractor."""
216
217        def hook(
218            module, inputs, outputs
219        ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
220            y: torch.Tensor = outputs.pop(self.output_name)
221            features = list(outputs.values())
222            return y, features
223
224        return hook
225
226    def _feature_adapter_vit_b_16(self):
227        """Returns a hook to adapt the features of the Vision Transformer
228        model."""
229
230        def hook(module, args, kwargs) -> List[torch.Tensor]:
231            seq_list = kwargs["features"]
232            # discard class token
233            seq_list = [seq[:, 1:, :] for seq in seq_list]
234            # reshape
235            seq_list = [
236                seq.reshape(seq.size(0), 14, 14, seq.size(2))
237                for seq in seq_list
238            ]
239            # bring channels after batch dimension
240            seq_list = [
241                seq.transpose(2, 3).transpose(1, 2) for seq in seq_list
242            ]
243            kwargs["features"] = seq_list
244            return args, kwargs
245
246        return hook
247
248
def build_ttame(**kwargs) -> TTAME:
    """Construct a TTAME model in a single call.

    Args:
        **kwargs: Passed verbatim to ``TTAMEBuilder``.

    Returns:
        TTAME: The assembled model.
    """
    return TTAMEBuilder(**kwargs).build()
259
260
def build_ttame_pipeline(**kwargs) -> PtamePipeline:
    """Build a ``PtamePipeline`` configured for TTAME.

    Args:
        **kwargs: Passed verbatim to ``TTAMEBuilder``; keys the builder does
            not consume are forwarded to ``PtamePipeline``.

    Returns:
        PtamePipeline: The assembled pipeline (not a ``TTAME`` instance).
    """
    builder = TTAMEBuilder(**kwargs)
    return builder.build_pipeline()
class TTAME(PtamePipeline, XAIModel):
    """Transformer-compatible Trainable Attention Mechanism for Explainability.

    The backbone must return a ``(logits, features)`` pair when called, and
    the attention mechanism turns ``features`` into class-wise saliency maps.
    """

    def __init__(
        self,
        backbone: nn.Module,
        attention: nn.Module,
        masking_procedure,
    ):
        """Create the model; all components are forwarded to ``PtamePipeline``.

        Args:
            backbone: Feature extractor returning ``(logits, features)``.
            attention: Module mapping ``features`` to saliency maps.
            masking_procedure: Callable ``(images, maps) -> masked_images``.
        """
        super().__init__(backbone, attention, masking_procedure)

    def _forward_backbone(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run the backbone (feature extractor) and return ``(logits, features)``."""
        y, features = self.backbone(x)
        return y, features

    def _forward_second(
        self, x: torch.Tensor, features: List[torch.Tensor], targets
    ):
        """Compute saliency maps, mask ``x`` with them and re-classify it.

        Args:
            x: Input batch to be masked.
            features: Backbone features computed for ``x``.
            targets: Class indices used to pick the maps applied as masks.

        Returns:
            ``(logits_masked, maps)`` where ``maps`` is whatever
            ``_select_returned_maps`` keeps for ``targets``.
        """
        saliency_maps = self.attention(features)
        # Mask the input with the maps that match the targets, then classify it.
        x_masked = self.masking_procedure(
            x, self._select_matching_maps(saliency_maps, targets)
        )
        y_masked, _ = self.backbone(x_masked)
        # Select the saliency maps which will be returned based on the
        # targets and the stage.
        saliency_maps = self._select_returned_maps(saliency_maps, targets)
        return y_masked, saliency_maps

    def forward(self, x: torch.Tensor):
        """Forward pass of the entire model.

        The backbone first runs on a detached copy of ``x``; its arg-max
        predictions become the targets for the masked second pass.

        Returns:
            dict: ``"logits"``, ``"logits_masked"``, ``"targets"`` and
            ``"masks"`` (the latter chosen by ``_select_returned_maps``).
        """
        y, features = self._forward_backbone(x.clone().detach())
        targets = y.argmax(dim=1)
        y_masked, maps = self._forward_second(x, features, targets)
        return {
            "logits": y,
            "logits_masked": y_masked,
            "targets": targets,
            "masks": maps,
        }

    def get_predictions(self, x):
        """Return only the backbone logits for ``x`` (no saliency maps).

        Useful for the evaluation measures.
        """
        return self._forward_backbone(x)[0]

    def produce_map(self, image: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Return the saliency map of the top-1 predicted class for one image.

        Args:
            image: A single, un-batched image; a batch dimension is added here.

        Returns:
            ``(saliency_map, prediction)``: ``prediction`` is the index of the
            largest logit and ``saliency_map`` the attention output for it.
        """
        logits, features = self._forward_backbone(image.unsqueeze(0))
        maps = self.attention(features)
        prediction = logits[0].argmax().item()
        # Named ``saliency_map`` rather than ``map`` to avoid shadowing the builtin.
        saliency_map = maps[0, prediction]
        return saliency_map, prediction

    def produce_cdmaps(
        self, image: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the saliency maps of all classes and the logits for one image.

        Args:
            image: A single, un-batched image; a batch dimension is added here.

        Returns:
            ``(maps, logits)``: the raw attention output for the 1-image batch
            and the logits of that image (a tensor, not a class index).
        """
        logits, features = self._forward_backbone(image.unsqueeze(0))
        maps = self.attention(features)
        return maps, logits[0]

Transformer-compatible Trainable Attention Mechanism for Explainability.

def __init__(self, backbone: nn.Module, attention: nn.Module, masking_procedure):
    """Create a TTAME model.

    The backbone, attention mechanism and masking procedure are handed, in
    that order, to the parent ``PtamePipeline`` initializer.
    """
    super().__init__(backbone, attention, masking_procedure)

Initialize the TTAME model by passing the backbone, attention mechanism and masking procedure unchanged to `PtamePipeline.__init__`.

def forward(self, x: torch.Tensor):
    """Run the full TTAME forward pass on a batch of images.

    The backbone sees a detached copy of ``x``. Its arg-max predictions then
    select the saliency maps used to mask ``x`` for the second pass.

    Returns:
        dict with keys ``"logits"``, ``"logits_masked"``, ``"targets"`` and
        ``"masks"``.
    """
    logits, feats = self._forward_backbone(x.clone().detach())
    predicted = logits.argmax(dim=1)
    masked_logits, selected_maps = self._forward_second(x, feats, predicted)
    return dict(
        logits=logits,
        logits_masked=masked_logits,
        targets=predicted,
        masks=selected_maps,
    )

Forward pass of the entire model.

Depending on the stage, different outputs are returned.

def get_predictions(self, x):
    """Return the backbone logits for ``x`` without computing saliency maps.

    Convenient for evaluation measures that only need class scores.
    """
    backbone_output = self._forward_backbone(x)
    logits = backbone_output[0]
    return logits

Get the predictions of the model.

Useful for the measures.

def produce_map(self, image: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Return the saliency map of the top-1 predicted class for one image.

    Args:
        image: A single, un-batched image; a batch dimension is added here.

    Returns:
        ``(saliency_map, prediction)``: ``prediction`` is the index of the
        largest logit and ``saliency_map`` the attention output for it.
    """
    logits, features = self._forward_backbone(image.unsqueeze(0))
    maps = self.attention(features)
    prediction = logits[0].argmax().item()
    # Named ``saliency_map`` rather than ``map`` to avoid shadowing the builtin.
    saliency_map = maps[0, prediction]
    return saliency_map, prediction

Produce the saliency map of the top-1 predicted class for the given image, returned together with that class index.

def produce_cdmaps(
    self, image: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the saliency maps of all classes and the logits for one image.

    Args:
        image: A single, un-batched image; a batch dimension is added here.

    Returns:
        ``(maps, logits)``: the raw attention output for the 1-image batch and
        the logits of that image (a tensor, not a class index).
    """
    logits, features = self._forward_backbone(image.unsqueeze(0))
    maps = self.attention(features)
    return maps, logits[0]

Produce the saliency maps of all classes for the given image, returned together with the image's logits.

class TTAMEBuilder:
 89class TTAMEBuilder:
 90    """Builder for the TTAME model."""
 91
 92    def __init__(
 93        self,
 94        backbone: nn.Module,
 95        attention: Type[nn.Module],
 96        masking_procedure,
 97        layers: List[str],
 98        input_dim: List[int] = None,
 99        num_classes: int = 1000,
100        **kwargs,
101    ):
102        self.backbone = backbone
103        self.attention = attention
104        self.masking_procedure = masking_procedure
105        self.layers = layers
106        self.input_dim = input_dim
107        self.num_classes = num_classes
108        self.kwargs = kwargs
109
110    @torch.no_grad()
111    def build(self) -> TTAME:
112        """Build the TTAME model."""
113        # build the feature extractor
114        self._build_fx()
115        # build the attention mechanism
116        self._build_attention()
117        # build the model
118        return TTAME(
119            self.backbone,
120            self.attention,
121            self.masking_procedure,
122        )
123
124    @torch.no_grad()
125    def build_pipeline(self) -> TTAME:
126        """Build the TTAME model."""
127        # build the feature extractor
128        self._build_fx()
129        # build the attention mechanism
130        self._build_attention()
131
132        def ttame_backbone_step(pipeline, _default_step):
133            def backbone_step(x: dict) -> dict:
134                """Forward pass of the backbone."""
135                if (inp := x.get("masked_images", None)) is not None:
136                    x["logits_masked"], _ = pipeline.backbone(inp)
137                    x["targets_masked"] = x["logits_masked"].argmax(dim=1)
138                else:
139                    x["logits"], x["features"] = pipeline.backbone(x["images"])
140                    x["targets"] = x["logits"].argmax(dim=1)
141                return x
142
143            return backbone_step
144
145        self.kwargs["hooks"] = {
146            "backbone": ttame_backbone_step,
147            "masked_backbone": ttame_backbone_step,
148        } | self.kwargs.get("hooks", {})
149        # build the model
150        return PtamePipeline(
151            self.backbone,
152            self.attention,
153            self.masking_procedure,
154            **self.kwargs,
155        )
156
157    def _build_fx(self):
158        """Build feature extractor."""
159        # if no layers are specified, print layers and quit
160        train_names, eval_names = get_graph_node_names(self.backbone)
161        if self.layers == [] or self.layers is None:
162            print(train_names)
163            quit()
164        # get the output layer name
165        output = (train_names[-1], eval_names[-1])
166        if output[0] != output[1]:
167            print(
168                "WARNING! THIS MODEL HAS DIFFERENT OUTPUTS FOR TRAIN AND EVAL MODE"
169            )
170        self.output_name = output[0]
171        # get feature extractor
172        self.backbone = create_feature_extractor(
173            self.backbone, return_nodes=(self.layers + [self.output_name])
174        )
175        # Dry run to get number of channels of each layer for the attention mechanism
176        if self.input_dim is not None:
177            inp = torch.randn(self.input_dim)
178        else:
179            inp = torch.randn(1, 3, 224, 224)
180        self.backbone.eval()
181        outputs = self.backbone(inp)
182        outputs.pop(self.output_name)
183        features = outputs.values()
184        self.feature_size = [o.shape for o in features]
185        self.backbone.register_forward_hook(self._simplify_graph_outputs())
186
187    def _build_attention(self):
188        """Build the attention mechanism."""
189        # check if the model is a transformer
190        if self._is_transformer():
191            feature_size = [
192                torch.Size([2, ft[-1], 14, 14]) for ft in self.feature_size
193            ]
194            self.attention = self.attention(feature_size, self.num_classes)
195            self.attention.register_forward_pre_hook(
196                self._feature_adapter_vit_b_16(), with_kwargs=True
197            )
198        else:
199            self.attention = self.attention(
200                self.feature_size, self.num_classes
201            )
202
203    def _is_transformer(self) -> bool:
204        """Check if the model is a transformer.
205
206        Returns:
207            bool: True if the model is a transformer, False otherwise.
208        """
209        for module in self.backbone.modules():
210            if isinstance(module, nn.MultiheadAttention):
211                return True
212        return False
213
214    def _simplify_graph_outputs(self):
215        """Returns a hook to simplify the outputs of the GraphModule feature
216        extractor."""
217
218        def hook(
219            module, inputs, outputs
220        ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
221            y: torch.Tensor = outputs.pop(self.output_name)
222            features = list(outputs.values())
223            return y, features
224
225        return hook
226
227    def _feature_adapter_vit_b_16(self):
228        """Returns a hook to adapt the features of the Vision Transformer
229        model."""
230
231        def hook(module, args, kwargs) -> List[torch.Tensor]:
232            seq_list = kwargs["features"]
233            # discard class token
234            seq_list = [seq[:, 1:, :] for seq in seq_list]
235            # reshape
236            seq_list = [
237                seq.reshape(seq.size(0), 14, 14, seq.size(2))
238                for seq in seq_list
239            ]
240            # bring channels after batch dimension
241            seq_list = [
242                seq.transpose(2, 3).transpose(1, 2) for seq in seq_list
243            ]
244            kwargs["features"] = seq_list
245            return args, kwargs
246
247        return hook

Builder for the TTAME model.

def __init__(
    self,
    backbone: nn.Module,
    attention: Type[nn.Module],
    masking_procedure,
    layers: List[str],
    input_dim: List[int] = None,
    num_classes: int = 1000,
    **kwargs,
):
    """Record the builder configuration; no model is built here.

    ``attention`` is a class (instantiated later). ``layers`` names the FX
    nodes used as features. ``input_dim`` is the dry-run input shape. Extra
    ``kwargs`` are kept for the pipeline constructor.
    """
    self.backbone, self.attention, self.masking_procedure = (
        backbone,
        attention,
        masking_procedure,
    )
    self.layers, self.input_dim = layers, input_dim
    self.num_classes = num_classes
    self.kwargs = kwargs
backbone
attention
masking_procedure
layers
input_dim
num_classes
kwargs
@torch.no_grad()
def build(self) -> TTAME:
    """Create the feature extractor and attention mechanism, then wrap them
    together with the masking procedure in a ``TTAME`` model."""
    for step in (self._build_fx, self._build_attention):
        step()
    components = (self.backbone, self.attention, self.masking_procedure)
    return TTAME(*components)

Build the TTAME model.

@torch.no_grad()
def build_pipeline(self) -> "PtamePipeline":
    """Build a ``PtamePipeline`` wired with TTAME's backbone hooks.

    The ``"backbone"`` and ``"masked_backbone"`` hooks default to a step that
    unpacks the feature extractor's ``(logits, features)`` output. Hooks
    supplied by the caller in ``kwargs["hooks"]`` take precedence.
    """
    self._build_fx()
    self._build_attention()

    def ttame_backbone_step(pipeline, _default_step):
        def backbone_step(x: dict) -> dict:
            """Run the backbone on masked images if present, else on images."""
            if (inp := x.get("masked_images", None)) is not None:
                x["logits_masked"], _ = pipeline.backbone(inp)
                x["targets_masked"] = x["logits_masked"].argmax(dim=1)
            else:
                x["logits"], x["features"] = pipeline.backbone(x["images"])
                x["targets"] = x["logits"].argmax(dim=1)
            return x

        return backbone_step

    # The right-hand operand of ``|`` wins, so caller hooks override ours;
    # ``or {}`` also tolerates an explicit ``hooks=None``.
    self.kwargs["hooks"] = {
        "backbone": ttame_backbone_step,
        "masked_backbone": ttame_backbone_step,
    } | (self.kwargs.get("hooks") or {})
    return PtamePipeline(
        self.backbone,
        self.attention,
        self.masking_procedure,
        **self.kwargs,
    )

Build a `PtamePipeline` configured with the TTAME feature extractor, attention mechanism and backbone hooks.

def build_ttame(**kwargs) -> TTAME:
    """Construct a TTAME model in a single call.

    Args:
        **kwargs: Passed verbatim to ``TTAMEBuilder``.

    Returns:
        TTAME: The assembled model.
    """
    return TTAMEBuilder(**kwargs).build()

Build the TTAME model.

Args: Same as TTAMEBuilder. Returns: TTAME: The TTAME model.

def build_ttame_pipeline(**kwargs) -> PtamePipeline:
    """Build a ``PtamePipeline`` configured for TTAME.

    Args:
        **kwargs: Passed verbatim to ``TTAMEBuilder``; keys the builder does
            not consume are forwarded to ``PtamePipeline``.

    Returns:
        PtamePipeline: The assembled pipeline (not a ``TTAME`` instance).
    """
    builder = TTAMEBuilder(**kwargs)
    return builder.build_pipeline()

Build a `PtamePipeline` configured for TTAME.

Args: Same as TTAMEBuilder. Returns: PtamePipeline: The TTAME pipeline.