ptame.models.components.attention.ptame_attention

  1import math
  2
  3import torch
  4from torch import nn
  5from torchvision.models import resnet18
  6from torchvision.models.feature_extraction import (
  7    create_feature_extractor,
  8    get_graph_node_names,
  9)
 10
 11from ptame.utils.masking_utils import minmax_4d
 12
 13
 14class PTAMEAttention(nn.Module):
 15    """PTAMEAttention is a class that represents the attention mechanism based
 16    on a feature extractor and P-TAME."""
 17
 18    def __init__(
 19        self,
 20        model: dict[str, list | nn.Module] = {
 21            "model": resnet18(weights="DEFAULT"),
 22            "layers": ["layer1", "layer2", "layer3", "layer4"],
 23        },
 24        input_dim: list[int] | None = [2, 3, 224, 224],
 25        scale_up: nn.Module = nn.Upsample(scale_factor=2, mode="bilinear"),
 26        activation: nn.Module = nn.ReLU(),
 27        fuser_bias: bool = True,
 28        num_classes: int = 1000,
 29        unfreeze: int = 0,
 30        cascading: bool = True,
 31        only_layer_train: bool = False,
 32        normalizers: list = [torch.sigmoid, minmax_4d],
 33    ):
 34        super().__init__()
 35        self.layer_names = model["layers"]
 36        fx_dict = self._make_fx(**model, input_dim=input_dim)
 37        self.attention = fx_dict["attention"]
 38        self.attention.requires_grad_(False).eval()
 39        self.ft_size = fx_dict["ft_size"]
 40        self.output_name = fx_dict["output_name"]
 41        self.scale_up = scale_up
 42        self.num_classes = num_classes
 43
 44        self.resolution = self.ft_size[0][-1]
 45        channels = [ft[0] for ft in self.ft_size]
 46        self.channels = channels
 47        self.convs = nn.ModuleList(
 48            [
 49                nn.Conv2d(
 50                    in_channels=c,
 51                    out_channels=c,
 52                    kernel_size=1,
 53                    padding=0,
 54                    bias=True,
 55                )
 56                for c in channels
 57            ]
 58        )
 59        self.bns = nn.ModuleList([nn.BatchNorm2d(c) for c in channels])
 60        self.act = activation
 61        self.fuser = nn.Conv2d(
 62            in_channels=sum(channels),
 63            out_channels=self.num_classes,
 64            kernel_size=1,
 65            padding=0,
 66            bias=fuser_bias,
 67        )
 68
 69        self.normalizers = normalizers
 70        self.trainable_layers = [self.convs, self.bns, self.fuser]
 71
 72    @staticmethod
 73    def list_processor(
 74        ops_list: nn.ModuleList, feature_list: list[torch.Tensor]
 75    ) -> list[torch.Tensor]:
 76        """Apply a list of operations to a list of feature maps.
 77
 78        Args:
 79            ops_list (nn.ModuleList): List of operations to apply.
 80            feature_list (list[torch.Tensor]): List of feature maps.
 81
 82        Returns:
 83            list[torch.Tensor]: List of feature maps with the operations applied.
 84        """
 85        return [op(feature) for op, feature in zip(ops_list, feature_list)]
 86
 87    @staticmethod
 88    def skip_connection(
 89        a: list[torch.Tensor], b: list[torch.Tensor]
 90    ) -> list[torch.Tensor]:
 91        """Add the feature maps of two lists.
 92
 93        Args:
 94            a (list[torch.Tensor]): First list of feature maps.
 95            b (list[torch.Tensor]): Second list of feature maps.
 96
 97        Returns:
 98            list[torch.Tensor]: List of feature maps with the two input lists added.
 99        """
100        return [a + b for a, b in zip(a, b)]
101
102    def list_activation(
103        self, feature_list: list[torch.Tensor]
104    ) -> list[torch.Tensor]:
105        """Apply the activation function to the feature maps.
106
107        Args:
108            feature_list (list[torch.Tensor]): List of feature maps.
109
110        Returns:
111            list[torch.Tensor]: List of feature maps with the activation function applied.
112        """
113        return [self.act(feature) for feature in feature_list]
114
115    def upscale(self, feature_list: list[torch.Tensor]) -> list[torch.Tensor]:
116        """Upscale the feature maps to the original resolution.
117
118        Args:
119            feature_list (list[torch.Tensor]): List of feature maps.
120
121        Returns:
122            list[torch.Tensor]: List of upscaled feature maps.
123        """
124
125        def up(map, times):
126            return up(self.scale_up(map), times - 1) if times > 0 else map
127
128        return [
129            up(
130                feature,
131                int(math.log(self.resolution / feature.shape[-1], 2)),
132            )
133            for feature in feature_list
134        ]
135
136    def forward(
137        self,
138        images: torch.Tensor,
139        with_outputs=False,
140        return_cms=False,
141        **kwargs,
142    ) -> torch.Tensor:
143        """Overwrite of attention forward method to include explanation
144        head."""
145        y, ft_maps = self.attention(images)
146
147        class_maps = self.list_processor(self.convs, ft_maps)
148        class_maps = self.list_processor(self.bns, class_maps)
149        class_maps = self.skip_connection(class_maps, ft_maps)
150        class_maps = self.list_activation(class_maps)
151        class_maps = self.upscale(class_maps)
152        class_maps = torch.cat(class_maps, 1)
153        if return_cms:
154            return class_maps
155        c = self.fuser(class_maps)
156        if with_outputs:
157            return y, c
158        if self.training or self.normalizers[1] is None:
159            return self.normalizers[0](c)
160        else:
161            return self.normalizers[1](c)
162
163    def _make_fx(self, model, layers, input_dim) -> dict:
164        """Create the feature extractor from the model."""
165        train_names, eval_names = get_graph_node_names(model)
166        if not layers:
167            print(train_names)
168            quit()
169        output = (train_names[-1], eval_names[-1])
170        if output[0] != output[1]:
171            print(
172                "WARNING! THIS MODEL HAS DIFFERENT OUTPUTS FOR TRAIN AND EVAL MODE"
173            )
174        output_name = output[0]
175        attention = create_feature_extractor(
176            model, return_nodes=layers + [output_name]
177        )
178        if input_dim is not None:
179            inp = torch.randn(input_dim)
180        else:
181            inp = torch.randn(2, 3, 224, 224)
182        attention.eval()
183        outputs = attention(inp)
184        outputs.pop(output_name)
185        features = outputs.values()
186        feature_size = [o.shape[1:] for o in features]
187
188        def hook(
189            module, inputs, outputs
190        ) -> tuple[torch.Tensor, list[torch.Tensor]]:
191            y: torch.Tensor = outputs.pop(output_name)
192            features = list(outputs.values())
193            return y, features
194
195        attention.register_forward_hook(hook)
196        return {
197            "attention": attention,
198            "ft_size": feature_size,
199            "output_name": output_name,
200        }
201
202    @torch.no_grad()
203    def get_contributions(self):
204        """Calculate the contributions of the attention mechanism."""
205        for name, param in self.fuser.named_parameters():
206            if "weight" in name:
207                weights = param.squeeze()
208        channels = self.channels.copy()
209        contribs = torch.stack(
210            [
211                a.sum(dim=1)
212                for a in weights.softmax(dim=1).split(channels, dim=1)
213            ],
214            dim=1,
215        )
216        return self.layer_names, contribs
217
218
def minmax_compose(fn):
    """Wrap ``fn`` so that its return value is passed through ``minmax_4d``.

    Args:
        fn: Callable whose result is accepted by ``minmax_4d``.

    Returns:
        A wrapper that forwards all arguments to ``fn``. ``functools.wraps``
        preserves ``fn``'s ``__name__``, ``__doc__`` and sets ``__wrapped__``.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return minmax_4d(fn(*args, **kwargs))

    return wrapper
class PTAMEAttention(torch.nn.modules.module.Module):
 15class PTAMEAttention(nn.Module):
 16    """PTAMEAttention is a class that represents the attention mechanism based
 17    on a feature extractor and P-TAME."""
 18
 19    def __init__(
 20        self,
 21        model: dict[str, list | nn.Module] = {
 22            "model": resnet18(weights="DEFAULT"),
 23            "layers": ["layer1", "layer2", "layer3", "layer4"],
 24        },
 25        input_dim: list[int] | None = [2, 3, 224, 224],
 26        scale_up: nn.Module = nn.Upsample(scale_factor=2, mode="bilinear"),
 27        activation: nn.Module = nn.ReLU(),
 28        fuser_bias: bool = True,
 29        num_classes: int = 1000,
 30        unfreeze: int = 0,
 31        cascading: bool = True,
 32        only_layer_train: bool = False,
 33        normalizers: list = [torch.sigmoid, minmax_4d],
 34    ):
 35        super().__init__()
 36        self.layer_names = model["layers"]
 37        fx_dict = self._make_fx(**model, input_dim=input_dim)
 38        self.attention = fx_dict["attention"]
 39        self.attention.requires_grad_(False).eval()
 40        self.ft_size = fx_dict["ft_size"]
 41        self.output_name = fx_dict["output_name"]
 42        self.scale_up = scale_up
 43        self.num_classes = num_classes
 44
 45        self.resolution = self.ft_size[0][-1]
 46        channels = [ft[0] for ft in self.ft_size]
 47        self.channels = channels
 48        self.convs = nn.ModuleList(
 49            [
 50                nn.Conv2d(
 51                    in_channels=c,
 52                    out_channels=c,
 53                    kernel_size=1,
 54                    padding=0,
 55                    bias=True,
 56                )
 57                for c in channels
 58            ]
 59        )
 60        self.bns = nn.ModuleList([nn.BatchNorm2d(c) for c in channels])
 61        self.act = activation
 62        self.fuser = nn.Conv2d(
 63            in_channels=sum(channels),
 64            out_channels=self.num_classes,
 65            kernel_size=1,
 66            padding=0,
 67            bias=fuser_bias,
 68        )
 69
 70        self.normalizers = normalizers
 71        self.trainable_layers = [self.convs, self.bns, self.fuser]
 72
 73    @staticmethod
 74    def list_processor(
 75        ops_list: nn.ModuleList, feature_list: list[torch.Tensor]
 76    ) -> list[torch.Tensor]:
 77        """Apply a list of operations to a list of feature maps.
 78
 79        Args:
 80            ops_list (nn.ModuleList): List of operations to apply.
 81            feature_list (list[torch.Tensor]): List of feature maps.
 82
 83        Returns:
 84            list[torch.Tensor]: List of feature maps with the operations applied.
 85        """
 86        return [op(feature) for op, feature in zip(ops_list, feature_list)]
 87
 88    @staticmethod
 89    def skip_connection(
 90        a: list[torch.Tensor], b: list[torch.Tensor]
 91    ) -> list[torch.Tensor]:
 92        """Add the feature maps of two lists.
 93
 94        Args:
 95            a (list[torch.Tensor]): First list of feature maps.
 96            b (list[torch.Tensor]): Second list of feature maps.
 97
 98        Returns:
 99            list[torch.Tensor]: List of feature maps with the two input lists added.
100        """
101        return [a + b for a, b in zip(a, b)]
102
103    def list_activation(
104        self, feature_list: list[torch.Tensor]
105    ) -> list[torch.Tensor]:
106        """Apply the activation function to the feature maps.
107
108        Args:
109            feature_list (list[torch.Tensor]): List of feature maps.
110
111        Returns:
112            list[torch.Tensor]: List of feature maps with the activation function applied.
113        """
114        return [self.act(feature) for feature in feature_list]
115
116    def upscale(self, feature_list: list[torch.Tensor]) -> list[torch.Tensor]:
117        """Upscale the feature maps to the original resolution.
118
119        Args:
120            feature_list (list[torch.Tensor]): List of feature maps.
121
122        Returns:
123            list[torch.Tensor]: List of upscaled feature maps.
124        """
125
126        def up(map, times):
127            return up(self.scale_up(map), times - 1) if times > 0 else map
128
129        return [
130            up(
131                feature,
132                int(math.log(self.resolution / feature.shape[-1], 2)),
133            )
134            for feature in feature_list
135        ]
136
137    def forward(
138        self,
139        images: torch.Tensor,
140        with_outputs=False,
141        return_cms=False,
142        **kwargs,
143    ) -> torch.Tensor:
144        """Overwrite of attention forward method to include explanation
145        head."""
146        y, ft_maps = self.attention(images)
147
148        class_maps = self.list_processor(self.convs, ft_maps)
149        class_maps = self.list_processor(self.bns, class_maps)
150        class_maps = self.skip_connection(class_maps, ft_maps)
151        class_maps = self.list_activation(class_maps)
152        class_maps = self.upscale(class_maps)
153        class_maps = torch.cat(class_maps, 1)
154        if return_cms:
155            return class_maps
156        c = self.fuser(class_maps)
157        if with_outputs:
158            return y, c
159        if self.training or self.normalizers[1] is None:
160            return self.normalizers[0](c)
161        else:
162            return self.normalizers[1](c)
163
164    def _make_fx(self, model, layers, input_dim) -> dict:
165        """Create the feature extractor from the model."""
166        train_names, eval_names = get_graph_node_names(model)
167        if not layers:
168            print(train_names)
169            quit()
170        output = (train_names[-1], eval_names[-1])
171        if output[0] != output[1]:
172            print(
173                "WARNING! THIS MODEL HAS DIFFERENT OUTPUTS FOR TRAIN AND EVAL MODE"
174            )
175        output_name = output[0]
176        attention = create_feature_extractor(
177            model, return_nodes=layers + [output_name]
178        )
179        if input_dim is not None:
180            inp = torch.randn(input_dim)
181        else:
182            inp = torch.randn(2, 3, 224, 224)
183        attention.eval()
184        outputs = attention(inp)
185        outputs.pop(output_name)
186        features = outputs.values()
187        feature_size = [o.shape[1:] for o in features]
188
189        def hook(
190            module, inputs, outputs
191        ) -> tuple[torch.Tensor, list[torch.Tensor]]:
192            y: torch.Tensor = outputs.pop(output_name)
193            features = list(outputs.values())
194            return y, features
195
196        attention.register_forward_hook(hook)
197        return {
198            "attention": attention,
199            "ft_size": feature_size,
200            "output_name": output_name,
201        }
202
203    @torch.no_grad()
204    def get_contributions(self):
205        """Calculate the contributions of the attention mechanism."""
206        for name, param in self.fuser.named_parameters():
207            if "weight" in name:
208                weights = param.squeeze()
209        channels = self.channels.copy()
210        contribs = torch.stack(
211            [
212                a.sum(dim=1)
213                for a in weights.softmax(dim=1).split(channels, dim=1)
214            ],
215            dim=1,
216        )
217        return self.layer_names, contribs

PTAMEAttention wraps a backbone model in a frozen feature extractor and adds a trainable P-TAME explanation head on top of its intermediate feature maps, producing one explanation map per class.

PTAMEAttention( model: dict[str, list | torch.nn.modules.module.Module] = {'model': ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): 
ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)) (fc): Linear(in_features=512, out_features=1000, bias=True) ), 'layers': ['layer1', 'layer2', 'layer3', 'layer4']}, input_dim: list[int] | None = [2, 3, 224, 224], scale_up: torch.nn.modules.module.Module = Upsample(scale_factor=2.0, mode='bilinear'), activation: torch.nn.modules.module.Module = ReLU(), fuser_bias: bool = True, num_classes: int = 1000, unfreeze: int = 0, cascading: bool = True, only_layer_train: bool = False, normalizers: list = [<built-in method sigmoid of type object>, <function minmax_4d>])
19    def __init__(
20        self,
21        model: dict[str, list | nn.Module] = {
22            "model": resnet18(weights="DEFAULT"),
23            "layers": ["layer1", "layer2", "layer3", "layer4"],
24        },
25        input_dim: list[int] | None = [2, 3, 224, 224],
26        scale_up: nn.Module = nn.Upsample(scale_factor=2, mode="bilinear"),
27        activation: nn.Module = nn.ReLU(),
28        fuser_bias: bool = True,
29        num_classes: int = 1000,
30        unfreeze: int = 0,
31        cascading: bool = True,
32        only_layer_train: bool = False,
33        normalizers: list = [torch.sigmoid, minmax_4d],
34    ):
35        super().__init__()
36        self.layer_names = model["layers"]
37        fx_dict = self._make_fx(**model, input_dim=input_dim)
38        self.attention = fx_dict["attention"]
39        self.attention.requires_grad_(False).eval()
40        self.ft_size = fx_dict["ft_size"]
41        self.output_name = fx_dict["output_name"]
42        self.scale_up = scale_up
43        self.num_classes = num_classes
44
45        self.resolution = self.ft_size[0][-1]
46        channels = [ft[0] for ft in self.ft_size]
47        self.channels = channels
48        self.convs = nn.ModuleList(
49            [
50                nn.Conv2d(
51                    in_channels=c,
52                    out_channels=c,
53                    kernel_size=1,
54                    padding=0,
55                    bias=True,
56                )
57                for c in channels
58            ]
59        )
60        self.bns = nn.ModuleList([nn.BatchNorm2d(c) for c in channels])
61        self.act = activation
62        self.fuser = nn.Conv2d(
63            in_channels=sum(channels),
64            out_channels=self.num_classes,
65            kernel_size=1,
66            padding=0,
67            bias=fuser_bias,
68        )
69
70        self.normalizers = normalizers
71        self.trainable_layers = [self.convs, self.bns, self.fuser]

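Build the frozen feature extractor from `model` and the trainable explanation head: a 1x1 convolution and a BatchNorm layer per extracted feature layer, plus a 1x1 fuser convolution that outputs `num_classes` maps.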

layer_names
attention
ft_size
output_name
scale_up
num_classes
resolution
channels
convs
bns
act
fuser
normalizers
trainable_layers
@staticmethod
def list_processor( ops_list: torch.nn.modules.container.ModuleList, feature_list: list[torch.Tensor]) -> list[torch.Tensor]:
73    @staticmethod
74    def list_processor(
75        ops_list: nn.ModuleList, feature_list: list[torch.Tensor]
76    ) -> list[torch.Tensor]:
77        """Apply a list of operations to a list of feature maps.
78
79        Args:
80            ops_list (nn.ModuleList): List of operations to apply.
81            feature_list (list[torch.Tensor]): List of feature maps.
82
83        Returns:
84            list[torch.Tensor]: List of feature maps with the operations applied.
85        """
86        return [op(feature) for op, feature in zip(ops_list, feature_list)]

Apply a list of operations to a list of feature maps.

Args: ops_list (nn.ModuleList): List of operations to apply. feature_list (list[torch.Tensor]): List of feature maps.

Returns: list[torch.Tensor]: List of feature maps with the operations applied.

@staticmethod
def skip_connection(a: list[torch.Tensor], b: list[torch.Tensor]) -> list[torch.Tensor]:
 88    @staticmethod
 89    def skip_connection(
 90        a: list[torch.Tensor], b: list[torch.Tensor]
 91    ) -> list[torch.Tensor]:
 92        """Add the feature maps of two lists.
 93
 94        Args:
 95            a (list[torch.Tensor]): First list of feature maps.
 96            b (list[torch.Tensor]): Second list of feature maps.
 97
 98        Returns:
 99            list[torch.Tensor]: List of feature maps with the two input lists added.
100        """
101        return [a + b for a, b in zip(a, b)]

Add the feature maps of two lists.

Args: a (list[torch.Tensor]): First list of feature maps. b (list[torch.Tensor]): Second list of feature maps.

Returns: list[torch.Tensor]: List of feature maps with the two input lists added.

def list_activation(self, feature_list: list[torch.Tensor]) -> list[torch.Tensor]:
103    def list_activation(
104        self, feature_list: list[torch.Tensor]
105    ) -> list[torch.Tensor]:
106        """Apply the activation function to the feature maps.
107
108        Args:
109            feature_list (list[torch.Tensor]): List of feature maps.
110
111        Returns:
112            list[torch.Tensor]: List of feature maps with the activation function applied.
113        """
114        return [self.act(feature) for feature in feature_list]

Apply the activation function to the feature maps.

Args: feature_list (list[torch.Tensor]): List of feature maps.

Returns: list[torch.Tensor]: List of feature maps with the activation function applied.

def upscale(self, feature_list: list[torch.Tensor]) -> list[torch.Tensor]:
116    def upscale(self, feature_list: list[torch.Tensor]) -> list[torch.Tensor]:
117        """Upscale the feature maps to the original resolution.
118
119        Args:
120            feature_list (list[torch.Tensor]): List of feature maps.
121
122        Returns:
123            list[torch.Tensor]: List of upscaled feature maps.
124        """
125
126        def up(map, times):
127            return up(self.scale_up(map), times - 1) if times > 0 else map
128
129        return [
130            up(
131                feature,
132                int(math.log(self.resolution / feature.shape[-1], 2)),
133            )
134            for feature in feature_list
135        ]

Upscale the feature maps to the original resolution.

Args: feature_list (list[torch.Tensor]): List of feature maps.

Returns: list[torch.Tensor]: List of upscaled feature maps.

def forward( self, images: torch.Tensor, with_outputs=False, return_cms=False, **kwargs) -> torch.Tensor:
137    def forward(
138        self,
139        images: torch.Tensor,
140        with_outputs=False,
141        return_cms=False,
142        **kwargs,
143    ) -> torch.Tensor:
144        """Overwrite of attention forward method to include explanation
145        head."""
146        y, ft_maps = self.attention(images)
147
148        class_maps = self.list_processor(self.convs, ft_maps)
149        class_maps = self.list_processor(self.bns, class_maps)
150        class_maps = self.skip_connection(class_maps, ft_maps)
151        class_maps = self.list_activation(class_maps)
152        class_maps = self.upscale(class_maps)
153        class_maps = torch.cat(class_maps, 1)
154        if return_cms:
155            return class_maps
156        c = self.fuser(class_maps)
157        if with_outputs:
158            return y, c
159        if self.training or self.normalizers[1] is None:
160            return self.normalizers[0](c)
161        else:
162            return self.normalizers[1](c)

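Run the frozen feature extractor and the explanation head on `images`. By default it returns the normalized class maps. With `return_cms=True` it returns the concatenated per-layer maps, and with `with_outputs=True` it returns the model output together with the raw fuser maps.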

@torch.no_grad()
def get_contributions(self):
203    @torch.no_grad()
204    def get_contributions(self):
205        """Calculate the contributions of the attention mechanism."""
206        for name, param in self.fuser.named_parameters():
207            if "weight" in name:
208                weights = param.squeeze()
209        channels = self.channels.copy()
210        contribs = torch.stack(
211            [
212                a.sum(dim=1)
213                for a in weights.softmax(dim=1).split(channels, dim=1)
214            ],
215            dim=1,
216        )
217        return self.layer_names, contribs

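Compute, for each class, the share of the softmax-normalized fuser weights that falls on each feature layer's channels. Returns the layer names and a tensor with one row per class and one column per layer.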

def minmax_compose(fn):
def minmax_compose(fn):
    """Wrap ``fn`` so that its return value is passed through ``minmax_4d``.

    Args:
        fn: Callable whose result is accepted by ``minmax_4d``.

    Returns:
        A wrapper that forwards all arguments to ``fn``. ``functools.wraps``
        preserves ``fn``'s ``__name__``, ``__doc__`` and sets ``__wrapped__``.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return minmax_4d(fn(*args, **kwargs))

    return wrapper