ptame.models.components.attention.ttame_attention
import functools
from typing import List

import torch
from torch import nn
from torch.nn import functional as F

from ptame.utils.masking_utils import minmax_4d


class TtameAttention(nn.Module):
    """Attention mechanism used in TTAME.

    Every input feature map is passed through its own 1x1 convolution and
    BatchNorm, added back to itself (skip connection), passed through a ReLU
    and bilinearly resized to a common square resolution. The resized maps
    are concatenated along the channel axis and fused by a 1x1 convolution
    into ``num_classes`` attention maps.

    Args:
        ft_size: Shapes of the feature maps that ``forward`` will receive.
            Index 1 of each shape is read as its channel count and index 2
            of the first shape as the spatial height (presumably NCHW
            ``torch.Size`` objects -- TODO confirm against the caller).
        num_classes: Number of attention maps produced (one per class).
    """

    def __init__(self, ft_size: List[torch.Size], num_classes=1000):
        super().__init__()
        # Target resolution: the first map's height, capped at 56 pixels.
        feat_height = min(ft_size[0][2], 56)
        self.resolution = feat_height
        # A functools.partial of a module-level function (unlike a lambda)
        # is picklable, so torch.save(model) and multiprocessing still work.
        self.interpolate = functools.partial(
            F.interpolate,
            size=(feat_height, feat_height),
            mode="bilinear",
            align_corners=False,
        )
        in_channels_list = [shape[1] for shape in ft_size]
        self.channels = in_channels_list
        # One channel-preserving 1x1 convolution per feature map.
        # noinspection PyTypeChecker
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=1,
                    padding=0,
                    bias=True,
                )
                for in_channels in in_channels_list
            ]
        )
        self.bn_channels = in_channels_list
        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(channels) for channels in self.bn_channels]
        )
        self.relu = nn.ReLU()
        # The fuser sees all branches concatenated along the channel axis, so
        # its input width is the total channel count over all feature maps.
        fuse_channels = sum(in_channels_list)
        # noinspection PyTypeChecker
        self.fuser = nn.Conv2d(
            in_channels=fuse_channels,
            out_channels=num_classes,
            kernel_size=1,
            padding=0,
            bias=True,
        )
        self.num_classes = num_classes

    def forward(self, features, **kwargs):
        """Fuse a list of feature maps into per-class attention maps.

        Args:
            features: One tensor per layer, in the same order as the
                ``ft_size`` entries the module was built with.
            **kwargs: Only ``opticam`` is read; when truthy, the first
                feature map is returned unchanged.

        Returns:
            A tensor with ``num_classes`` channels at ``self.resolution`` x
            ``self.resolution``: sigmoid activations while training, passed
            through ``minmax_4d`` otherwise (presumably per-map min-max
            normalisation -- see that helper).
        """
        if kwargs.get("opticam", False):
            return features[0]
        fused_inputs = []
        # zip stops at the shortest sequence, so features beyond
        # len(self.convs) are ignored.
        for conv, bn, feature in zip(self.convs, self.bns, features):
            # 1x1 conv -> BatchNorm, plus a residual connection to the raw map.
            branch = self.relu(bn(conv(feature)) + feature)
            # Maps may differ in size; bring them to one resolution so they
            # can be concatenated.
            fused_inputs.append(self.interpolate(branch))
        # Concatenate along channels and fuse into num_classes channels.
        c = self.fuser(torch.cat(fused_inputs, 1))
        if not self.training:
            return minmax_4d(c)
        return torch.sigmoid(c)

    @torch.no_grad()
    def get_contributions(self):
        """Estimate how much each input feature map contributes to each class.

        The fuser's 1x1 kernel is viewed as a ``(num_classes, total_channels)``
        matrix, softmax-normalised over the input channels, and the weights
        are summed per feature map.

        Returns:
            A pair ``(names, contribs)``. ``names`` labels the feature maps
            ``"0"``, ``"1"``, ... . ``contribs`` has shape
            ``(num_classes, len(self.channels))``; each row sums to 1 because
            the split covers every fuser input channel.
        """
        # flatten(1) rather than squeeze(): squeeze() would also drop the
        # class axis when num_classes == 1 and break the dim=1 softmax.
        weights = self.fuser.weight.flatten(1)
        per_layer = weights.softmax(dim=1).split(self.channels, dim=1)
        contribs = torch.stack([w.sum(dim=1) for w in per_layer], dim=1)
        return [str(i) for i in range(len(self.channels))], contribs
class
TtameAttention(torch.nn.modules.module.Module):
class TtameAttention(nn.Module):
    """Attention mechanism used in TTAME.

    Every input feature map is passed through its own 1x1 convolution and
    BatchNorm, added back to itself (skip connection), passed through a ReLU
    and bilinearly resized to a common square resolution. The resized maps
    are concatenated along the channel axis and fused by a 1x1 convolution
    into ``num_classes`` attention maps.

    Args:
        ft_size: Shapes of the feature maps that ``forward`` will receive.
            Index 1 of each shape is read as its channel count and index 2
            of the first shape as the spatial height (presumably NCHW
            ``torch.Size`` objects -- TODO confirm against the caller).
        num_classes: Number of attention maps produced (one per class).
    """

    def __init__(self, ft_size: List[torch.Size], num_classes=1000):
        import functools

        super().__init__()
        # Target resolution: the first map's height, capped at 56 pixels.
        feat_height = min(ft_size[0][2], 56)
        self.resolution = feat_height
        # A functools.partial of a module-level function (unlike a lambda)
        # is picklable, so torch.save(model) and multiprocessing still work.
        self.interpolate = functools.partial(
            F.interpolate,
            size=(feat_height, feat_height),
            mode="bilinear",
            align_corners=False,
        )
        in_channels_list = [shape[1] for shape in ft_size]
        self.channels = in_channels_list
        # One channel-preserving 1x1 convolution per feature map.
        # noinspection PyTypeChecker
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=1,
                    padding=0,
                    bias=True,
                )
                for in_channels in in_channels_list
            ]
        )
        self.bn_channels = in_channels_list
        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(channels) for channels in self.bn_channels]
        )
        self.relu = nn.ReLU()
        # The fuser sees all branches concatenated along the channel axis, so
        # its input width is the total channel count over all feature maps.
        fuse_channels = sum(in_channels_list)
        # noinspection PyTypeChecker
        self.fuser = nn.Conv2d(
            in_channels=fuse_channels,
            out_channels=num_classes,
            kernel_size=1,
            padding=0,
            bias=True,
        )
        self.num_classes = num_classes

    def forward(self, features, **kwargs):
        """Fuse a list of feature maps into per-class attention maps.

        Args:
            features: One tensor per layer, in the same order as the
                ``ft_size`` entries the module was built with.
            **kwargs: Only ``opticam`` is read; when truthy, the first
                feature map is returned unchanged.

        Returns:
            A tensor with ``num_classes`` channels at ``self.resolution`` x
            ``self.resolution``: sigmoid activations while training, passed
            through ``minmax_4d`` otherwise (presumably per-map min-max
            normalisation -- see that helper).
        """
        if kwargs.get("opticam", False):
            return features[0]
        fused_inputs = []
        # zip stops at the shortest sequence, so features beyond
        # len(self.convs) are ignored.
        for conv, bn, feature in zip(self.convs, self.bns, features):
            # 1x1 conv -> BatchNorm, plus a residual connection to the raw map.
            branch = self.relu(bn(conv(feature)) + feature)
            # Maps may differ in size; bring them to one resolution so they
            # can be concatenated.
            fused_inputs.append(self.interpolate(branch))
        # Concatenate along channels and fuse into num_classes channels.
        c = self.fuser(torch.cat(fused_inputs, 1))
        if not self.training:
            return minmax_4d(c)
        return torch.sigmoid(c)

    @torch.no_grad()
    def get_contributions(self):
        """Estimate how much each input feature map contributes to each class.

        The fuser's 1x1 kernel is viewed as a ``(num_classes, total_channels)``
        matrix, softmax-normalised over the input channels, and the weights
        are summed per feature map.

        Returns:
            A pair ``(names, contribs)``. ``names`` labels the feature maps
            ``"0"``, ``"1"``, ... . ``contribs`` has shape
            ``(num_classes, len(self.channels))``; each row sums to 1 because
            the split covers every fuser input channel.
        """
        # flatten(1) rather than squeeze(): squeeze() would also drop the
        # class axis when num_classes == 1 and break the dim=1 softmax.
        weights = self.fuser.weight.flatten(1)
        per_layer = weights.softmax(dim=1).split(self.channels, dim=1)
        contribs = torch.stack([w.sum(dim=1) for w in per_layer], dim=1)
        return [str(i) for i in range(len(self.channels))], contribs
TtameAttention is the attention mechanism used in TTAME: it fuses feature maps taken from several layers into one attention map per class.
TtameAttention(ft_size: List[torch.Size], num_classes=1000)
def __init__(self, ft_size: List[torch.Size], num_classes=1000):
    """Build the per-layer branches and the fusing convolution.

    Args:
        ft_size: Shapes of the feature maps that ``forward`` will receive.
            Index 1 of each shape is read as its channel count and index 2
            of the first shape as the spatial height (presumably NCHW
            ``torch.Size`` objects -- TODO confirm against the caller).
        num_classes: Number of attention maps produced (one per class).
    """
    import functools

    super().__init__()
    # Target resolution: the first map's height, capped at 56 pixels.
    feat_height = min(ft_size[0][2], 56)
    self.resolution = feat_height
    # A functools.partial of a module-level function (unlike a lambda) is
    # picklable, so torch.save(model) and multiprocessing still work.
    self.interpolate = functools.partial(
        F.interpolate,
        size=(feat_height, feat_height),
        mode="bilinear",
        align_corners=False,
    )
    in_channels_list = [shape[1] for shape in ft_size]
    self.channels = in_channels_list
    # One channel-preserving 1x1 convolution per feature map.
    # noinspection PyTypeChecker
    self.convs = nn.ModuleList(
        [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=1,
                padding=0,
                bias=True,
            )
            for in_channels in in_channels_list
        ]
    )
    self.bn_channels = in_channels_list
    self.bns = nn.ModuleList(
        [nn.BatchNorm2d(channels) for channels in self.bn_channels]
    )
    self.relu = nn.ReLU()
    # The fuser sees all branches concatenated along the channel axis, so its
    # input width is the total channel count over all feature maps.
    fuse_channels = sum(in_channels_list)
    # noinspection PyTypeChecker
    self.fuser = nn.Conv2d(
        in_channels=fuse_channels,
        out_channels=num_classes,
        kernel_size=1,
        padding=0,
        bias=True,
    )
    self.num_classes = num_classes
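Initialize internal Module state (shared by both nn.Module and ScriptModule), then build one 1×1 convolution and one BatchNorm layer per input feature map, plus a 1×1 `fuser` convolution that maps their concatenation to `num_classes` channels.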
def
forward(self, features, **kwargs):
def forward(self, features, **kwargs):
    """Fuse a list of feature maps into per-class attention maps.

    Args:
        features: One tensor per layer, in the same order as the ``ft_size``
            entries the module was built with.
        **kwargs: Only ``opticam`` is read; when truthy, the first feature
            map is returned unchanged.

    Returns:
        The output of ``self.fuser`` (``num_classes`` channels, resized by
        ``self.interpolate``): sigmoid activations while training, passed
        through ``minmax_4d`` otherwise (presumably per-map min-max
        normalisation -- see that helper).
    """
    if kwargs.get("opticam", False):
        return features[0]
    fused_inputs = []
    # zip stops at the shortest sequence, so features beyond len(self.convs)
    # are ignored.
    for conv, bn, feature in zip(self.convs, self.bns, features):
        # 1x1 conv -> BatchNorm, plus a residual connection to the raw map.
        branch = self.relu(bn(conv(feature)) + feature)
        # Maps may differ in size; bring them to one resolution so they can
        # be concatenated.
        fused_inputs.append(self.interpolate(branch))
    # Concatenate along channels and fuse into num_classes channels.
    c = self.fuser(torch.cat(fused_inputs, 1))
    if not self.training:
        return minmax_4d(c)
    return torch.sigmoid(c)
Forward pass of the attention mechanism.
@torch.no_grad()
def
get_contributions(self):
@torch.no_grad()
def get_contributions(self):
    """Estimate how much each input feature map contributes to each class.

    The fuser's 1x1 kernel is viewed as a ``(num_classes, total_channels)``
    matrix, softmax-normalised over the input channels, and the weights are
    summed per feature map.

    Returns:
        A pair ``(names, contribs)``. ``names`` labels the feature maps
        ``"0"``, ``"1"``, ... . ``contribs`` has shape
        ``(num_classes, len(self.channels))``; each row sums to 1 because the
        split covers every fuser input channel.
    """
    # flatten(1) rather than squeeze(): squeeze() would also drop the class
    # axis when num_classes == 1 and break the dim=1 softmax.
    weights = self.fuser.weight.flatten(1)
    per_layer = weights.softmax(dim=1).split(self.channels, dim=1)
    contribs = torch.stack([w.sum(dim=1) for w in per_layer], dim=1)
    return [str(i) for i in range(len(self.channels))], contribs
Calculate the contributions of the attention mechanism.