ptame.models.components.attention.ttame_attention

  1from typing import List
  2
  3import torch
  4from torch import nn
  5from torch.nn import functional as F
  6
  7from ptame.utils.masking_utils import minmax_4d
  8
  9
 10class TtameAttention(nn.Module):
 11    """TtameAttention is a class that represents the attention mechanism used
 12    in TTAME."""
 13
 14    def __init__(self, ft_size: List[torch.Size], num_classes=1000):
 15        super().__init__()
 16        feat_height = ft_size[0][2] if ft_size[0][2] <= 56 else 56
 17        self.resolution = feat_height
 18        self.interpolate = lambda inp: F.interpolate(
 19            inp,
 20            size=(feat_height, feat_height),
 21            mode="bilinear",
 22            align_corners=False,
 23        )
 24        in_channels_list = [o[1] for o in ft_size]
 25        self.channels = in_channels_list
 26        # noinspection PyTypeChecker
 27        self.convs = nn.ModuleList(
 28            [
 29                nn.Conv2d(
 30                    in_channels=in_channels,
 31                    out_channels=in_channels,
 32                    kernel_size=1,
 33                    padding=0,
 34                    bias=True,
 35                )
 36                for in_channels in in_channels_list
 37            ]
 38        )
 39        self.bn_channels = in_channels_list
 40        self.bns = nn.ModuleList(
 41            [nn.BatchNorm2d(channels) for channels in self.bn_channels]
 42        )
 43        self.relu = nn.ReLU()
 44        # for each extra layer we need 1000 more channels to input to the fuse
 45        # convolution
 46        fuse_channels = sum(in_channels_list)
 47        # noinspection PyTypeChecker
 48        self.fuser = nn.Conv2d(
 49            in_channels=fuse_channels,
 50            out_channels=num_classes,
 51            kernel_size=1,
 52            padding=0,
 53            bias=True,
 54        )
 55        self.num_classes = num_classes
 56
 57    def forward(self, features, **kwargs):
 58        """Forward pass of the attention mechanism."""
 59        # Fusion Strategy
 60        feature_maps = features
 61        if kwargs.get("opticam", False):
 62            return features[0]
 63        # Now all feature map sets are of the same HxW
 64        # conv
 65        class_maps = [
 66            op(feature) for op, feature in zip(self.convs, feature_maps)
 67        ]
 68        # batch norm
 69        class_maps = [op(feature) for op, feature in zip(self.bns, class_maps)]
 70        # add (skip connection)
 71        class_maps = [
 72            class_map + feature_map
 73            for class_map, feature_map in zip(class_maps, feature_maps)
 74        ]
 75        # activation
 76        class_maps = [self.relu(class_map) for class_map in class_maps]
 77        # upscale
 78        class_maps = [self.interpolate(feature) for feature in class_maps]
 79        # concat
 80        class_maps = torch.cat(class_maps, 1)
 81        # fuse into num_classes channels
 82        c = self.fuser(class_maps)  # batch_size x1xWxH
 83        if not self.training:
 84            return minmax_4d(c)
 85        else:
 86            return torch.sigmoid(c)
 87
 88    @torch.no_grad()
 89    def get_contributions(self):
 90        """Calculate the contributions of the attention mechanism."""
 91        for name, param in self.fuser.named_parameters():
 92            if "weight" in name:
 93                weights = param.squeeze()
 94        channels = self.channels.copy()
 95        contribs = torch.stack(
 96            [
 97                a.sum(dim=1)
 98                for a in weights.softmax(dim=1).split(channels, dim=1)
 99            ],
100            dim=1,
101        )
102        return [f"{i}" for i in range(len(self.channels))], contribs
class TtameAttention(torch.nn.modules.module.Module):
 11class TtameAttention(nn.Module):
 12    """TtameAttention is a class that represents the attention mechanism used
 13    in TTAME."""
 14
 15    def __init__(self, ft_size: List[torch.Size], num_classes=1000):
 16        super().__init__()
 17        feat_height = ft_size[0][2] if ft_size[0][2] <= 56 else 56
 18        self.resolution = feat_height
 19        self.interpolate = lambda inp: F.interpolate(
 20            inp,
 21            size=(feat_height, feat_height),
 22            mode="bilinear",
 23            align_corners=False,
 24        )
 25        in_channels_list = [o[1] for o in ft_size]
 26        self.channels = in_channels_list
 27        # noinspection PyTypeChecker
 28        self.convs = nn.ModuleList(
 29            [
 30                nn.Conv2d(
 31                    in_channels=in_channels,
 32                    out_channels=in_channels,
 33                    kernel_size=1,
 34                    padding=0,
 35                    bias=True,
 36                )
 37                for in_channels in in_channels_list
 38            ]
 39        )
 40        self.bn_channels = in_channels_list
 41        self.bns = nn.ModuleList(
 42            [nn.BatchNorm2d(channels) for channels in self.bn_channels]
 43        )
 44        self.relu = nn.ReLU()
 45        # for each extra layer we need 1000 more channels to input to the fuse
 46        # convolution
 47        fuse_channels = sum(in_channels_list)
 48        # noinspection PyTypeChecker
 49        self.fuser = nn.Conv2d(
 50            in_channels=fuse_channels,
 51            out_channels=num_classes,
 52            kernel_size=1,
 53            padding=0,
 54            bias=True,
 55        )
 56        self.num_classes = num_classes
 57
 58    def forward(self, features, **kwargs):
 59        """Forward pass of the attention mechanism."""
 60        # Fusion Strategy
 61        feature_maps = features
 62        if kwargs.get("opticam", False):
 63            return features[0]
 64        # Now all feature map sets are of the same HxW
 65        # conv
 66        class_maps = [
 67            op(feature) for op, feature in zip(self.convs, feature_maps)
 68        ]
 69        # batch norm
 70        class_maps = [op(feature) for op, feature in zip(self.bns, class_maps)]
 71        # add (skip connection)
 72        class_maps = [
 73            class_map + feature_map
 74            for class_map, feature_map in zip(class_maps, feature_maps)
 75        ]
 76        # activation
 77        class_maps = [self.relu(class_map) for class_map in class_maps]
 78        # upscale
 79        class_maps = [self.interpolate(feature) for feature in class_maps]
 80        # concat
 81        class_maps = torch.cat(class_maps, 1)
 82        # fuse into num_classes channels
 83        c = self.fuser(class_maps)  # batch_size x1xWxH
 84        if not self.training:
 85            return minmax_4d(c)
 86        else:
 87            return torch.sigmoid(c)
 88
 89    @torch.no_grad()
 90    def get_contributions(self):
 91        """Calculate the contributions of the attention mechanism."""
 92        for name, param in self.fuser.named_parameters():
 93            if "weight" in name:
 94                weights = param.squeeze()
 95        channels = self.channels.copy()
 96        contribs = torch.stack(
 97            [
 98                a.sum(dim=1)
 99                for a in weights.softmax(dim=1).split(channels, dim=1)
100            ],
101            dim=1,
102        )
103        return [f"{i}" for i in range(len(self.channels))], contribs

TtameAttention implements the attention mechanism used in TTAME: it fuses several feature maps into `num_classes` attention maps.

TtameAttention(ft_size: List[torch.Size], num_classes=1000)
15    def __init__(self, ft_size: List[torch.Size], num_classes=1000):
16        super().__init__()
17        feat_height = ft_size[0][2] if ft_size[0][2] <= 56 else 56
18        self.resolution = feat_height
19        self.interpolate = lambda inp: F.interpolate(
20            inp,
21            size=(feat_height, feat_height),
22            mode="bilinear",
23            align_corners=False,
24        )
25        in_channels_list = [o[1] for o in ft_size]
26        self.channels = in_channels_list
27        # noinspection PyTypeChecker
28        self.convs = nn.ModuleList(
29            [
30                nn.Conv2d(
31                    in_channels=in_channels,
32                    out_channels=in_channels,
33                    kernel_size=1,
34                    padding=0,
35                    bias=True,
36                )
37                for in_channels in in_channels_list
38            ]
39        )
40        self.bn_channels = in_channels_list
41        self.bns = nn.ModuleList(
42            [nn.BatchNorm2d(channels) for channels in self.bn_channels]
43        )
44        self.relu = nn.ReLU()
45        # for each extra layer we need 1000 more channels to input to the fuse
46        # convolution
47        fuse_channels = sum(in_channels_list)
48        # noinspection PyTypeChecker
49        self.fuser = nn.Conv2d(
50            in_channels=fuse_channels,
51            out_channels=num_classes,
52            kernel_size=1,
53            padding=0,
54            bias=True,
55        )
56        self.num_classes = num_classes

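Initialize the module state: one 1x1 convolution and BatchNorm layer per input feature map, plus a 1x1 fusion convolution mapping their concatenated channels to `num_classes` attention maps.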

resolution
interpolate
channels
convs
bn_channels
bns
relu
fuser
num_classes
def forward(self, features, **kwargs):
58    def forward(self, features, **kwargs):
59        """Forward pass of the attention mechanism."""
60        # Fusion Strategy
61        feature_maps = features
62        if kwargs.get("opticam", False):
63            return features[0]
64        # Now all feature map sets are of the same HxW
65        # conv
66        class_maps = [
67            op(feature) for op, feature in zip(self.convs, feature_maps)
68        ]
69        # batch norm
70        class_maps = [op(feature) for op, feature in zip(self.bns, class_maps)]
71        # add (skip connection)
72        class_maps = [
73            class_map + feature_map
74            for class_map, feature_map in zip(class_maps, feature_maps)
75        ]
76        # activation
77        class_maps = [self.relu(class_map) for class_map in class_maps]
78        # upscale
79        class_maps = [self.interpolate(feature) for feature in class_maps]
80        # concat
81        class_maps = torch.cat(class_maps, 1)
82        # fuse into num_classes channels
83        c = self.fuser(class_maps)  # batch_size x1xWxH
84        if not self.training:
85            return minmax_4d(c)
86        else:
87            return torch.sigmoid(c)

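Forward pass of the attention mechanism: fuses the input feature maps into `num_classes` attention maps (sigmoid-activated in training mode, passed through `minmax_4d` in eval mode).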

@torch.no_grad()
def get_contributions(self):
 89    @torch.no_grad()
 90    def get_contributions(self):
 91        """Calculate the contributions of the attention mechanism."""
 92        for name, param in self.fuser.named_parameters():
 93            if "weight" in name:
 94                weights = param.squeeze()
 95        channels = self.channels.copy()
 96        contribs = torch.stack(
 97            [
 98                a.sum(dim=1)
 99                for a in weights.softmax(dim=1).split(channels, dim=1)
100            ],
101            dim=1,
102        )
103        return [f"{i}" for i in range(len(self.channels))], contribs

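Calculate each input layer's contribution to each class map from the softmax-normalized fusion weights; returns the layer labels and a `(num_classes, num_layers)` tensor whose rows sum to 1.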