ptame.models.components.attention.ptame_attention
import functools
import math
import warnings

import torch
from torch import nn
from torchvision.models import resnet18
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)

from ptame.utils.masking_utils import minmax_4d


class PTAMEAttention(nn.Module):
    """Attention mechanism built from a frozen feature extractor and P-TAME.

    The wrapped backbone is turned into a feature extractor that returns the
    backbone output together with the feature maps of the requested layers.

    A trainable head then turns those maps into ``num_classes`` attention
    maps. The head consists of:

    - per-layer 1x1 conv + batch norm with a residual connection;
    - an activation;
    - upsampling to a common resolution;
    - a 1x1 "fuser" conv.
    """

    def __init__(
        self,
        model: dict[str, list | nn.Module] | None = None,
        input_dim: list[int] | tuple[int, ...] | None = (2, 3, 224, 224),
        scale_up: nn.Module | None = None,
        activation: nn.Module | None = None,
        fuser_bias: bool = True,
        num_classes: int = 1000,
        unfreeze: int = 0,
        cascading: bool = True,
        only_layer_train: bool = False,
        normalizers: list | None = None,
    ):
        """Build the frozen feature extractor and the trainable P-TAME head.

        Args:
            model: Dict with keys ``"model"`` (the backbone module) and
                ``"layers"`` (graph node names used as feature maps).
                Defaults to a pretrained ResNet-18 with
                ``["layer1", "layer2", "layer3", "layer4"]``, built per call
                rather than once at import time.
            input_dim: Shape of the dummy input used once to probe the
                feature-map sizes. ``None`` means ``(2, 3, 224, 224)``.
            scale_up: Upsampling module; ``upscale`` applies it
                ``log2(ratio)`` times, so it is expected to double the
                spatial size. Defaults to a fresh
                ``nn.Upsample(scale_factor=2, mode="bilinear")``.
            activation: Activation applied after the residual connection.
                Defaults to a fresh ``nn.ReLU()``.
            fuser_bias: Whether the 1x1 fuser convolution has a bias.
            num_classes: Number of output maps (fuser output channels).
            unfreeze: Not used by this class; kept for compatibility.
            cascading: Not used by this class; kept for compatibility.
            only_layer_train: Not used by this class; kept for compatibility.
            normalizers: ``[train_normalizer, eval_normalizer]`` applied to
                the fuser output in ``forward``. Defaults to
                ``[torch.sigmoid, minmax_4d]``.
        """
        super().__init__()
        # Defaults are created here, not in the signature, so nothing is
        # downloaded at import time and no module/list is shared between
        # instances.
        if model is None:
            model = {
                "model": resnet18(weights="DEFAULT"),
                "layers": ["layer1", "layer2", "layer3", "layer4"],
            }
        if scale_up is None:
            scale_up = nn.Upsample(scale_factor=2, mode="bilinear")
        if activation is None:
            activation = nn.ReLU()
        if normalizers is None:
            normalizers = [torch.sigmoid, minmax_4d]

        self.layer_names = model["layers"]
        fx_dict = self._make_fx(**model, input_dim=input_dim)
        self.attention = fx_dict["attention"]
        # Freeze the backbone.
        # NOTE(review): self.attention is a registered submodule, so calling
        # .train() on this module puts it back in train mode (e.g. BatchNorm
        # running stats). Confirm callers re-apply .eval() if that matters.
        self.attention.requires_grad_(False).eval()
        self.ft_size = fx_dict["ft_size"]
        self.output_name = fx_dict["output_name"]
        self.scale_up = scale_up
        self.num_classes = num_classes

        # Target width for upscale(): the first feature map's width.
        # Presumably the highest-resolution layer; TODO confirm.
        self.resolution = self.ft_size[0][-1]
        channels = [ft[0] for ft in self.ft_size]
        self.channels = channels
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=c,
                    out_channels=c,
                    kernel_size=1,
                    padding=0,
                    bias=True,
                )
                for c in channels
            ]
        )
        self.bns = nn.ModuleList([nn.BatchNorm2d(c) for c in channels])
        self.act = activation
        self.fuser = nn.Conv2d(
            in_channels=sum(channels),
            out_channels=self.num_classes,
            kernel_size=1,
            padding=0,
            bias=fuser_bias,
        )

        self.normalizers = normalizers
        self.trainable_layers = [self.convs, self.bns, self.fuser]

    @staticmethod
    def list_processor(
        ops_list: nn.ModuleList, feature_list: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Apply ``ops_list[i]`` to ``feature_list[i]`` pairwise.

        Args:
            ops_list (nn.ModuleList): Operations to apply.
            feature_list (list[torch.Tensor]): Feature maps.

        Returns:
            list[torch.Tensor]: Results, truncated to the shorter input.
        """
        return [op(feature) for op, feature in zip(ops_list, feature_list)]

    @staticmethod
    def skip_connection(
        a: list[torch.Tensor], b: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Add two lists of feature maps element-wise.

        Args:
            a (list[torch.Tensor]): First list of feature maps.
            b (list[torch.Tensor]): Second list of feature maps.

        Returns:
            list[torch.Tensor]: ``a[i] + b[i]``, truncated to the shorter list.
        """
        return [left + right for left, right in zip(a, b)]

    def list_activation(
        self, feature_list: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Apply ``self.act`` to every feature map.

        Args:
            feature_list (list[torch.Tensor]): Feature maps.

        Returns:
            list[torch.Tensor]: Activated feature maps.
        """
        return [self.act(feature) for feature in feature_list]

    def upscale(self, feature_list: list[torch.Tensor]) -> list[torch.Tensor]:
        """Upscale the feature maps to ``self.resolution``.

        Each map goes through ``self.scale_up`` ``round(log2(ratio))``
        times, where ``ratio = self.resolution / width``. This assumes
        ``scale_up`` doubles the size and the ratios are powers of two.
        Maps already at or above ``self.resolution`` are returned unchanged.

        Args:
            feature_list (list[torch.Tensor]): Feature maps.

        Returns:
            list[torch.Tensor]: Upscaled feature maps.
        """
        upscaled = []
        for feature in feature_list:
            # math.log2 is exact for powers of two, unlike math.log(x, 2);
            # round() guards against truncating e.g. 2.9999... down to 2.
            steps = round(math.log2(self.resolution / feature.shape[-1]))
            for _ in range(steps):
                feature = self.scale_up(feature)
            upscaled.append(feature)
        return upscaled

    def forward(
        self,
        images: torch.Tensor,
        with_outputs=False,
        return_cms=False,
        **kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Override of the attention forward that adds the explanation head.

        Args:
            images: Input batch passed to the feature extractor.
            with_outputs: If True, return ``(y, c)``. ``y`` is the backbone
                output and ``c`` is the raw (un-normalized) fuser output.
            return_cms: If True, return the concatenated, upscaled per-layer
                maps before the fuser. Takes precedence over
                ``with_outputs``.
            **kwargs: Ignored.

        Returns:
            By default ``normalizers[0](c)`` in training mode, or when no
            eval normalizer is set; otherwise ``normalizers[1](c)``.
        """
        y, ft_maps = self.attention(images)

        class_maps = self.list_processor(self.convs, ft_maps)
        class_maps = self.list_processor(self.bns, class_maps)
        class_maps = self.skip_connection(class_maps, ft_maps)
        class_maps = self.list_activation(class_maps)
        class_maps = self.upscale(class_maps)
        class_maps = torch.cat(class_maps, 1)
        if return_cms:
            return class_maps
        c = self.fuser(class_maps)
        if with_outputs:
            return y, c
        # Tolerate a single-element normalizers list (no eval normalizer).
        if len(self.normalizers) > 1:
            eval_normalizer = self.normalizers[1]
        else:
            eval_normalizer = None
        if self.training or eval_normalizer is None:
            return self.normalizers[0](c)
        return eval_normalizer(c)

    def _make_fx(self, model, layers, input_dim) -> dict:
        """Wrap ``model`` in an extractor returning ``(y, feature_maps)``.

        Args:
            model: Backbone module traced with torchvision FX.
            layers: Graph node names whose outputs become feature maps.
            input_dim: Dummy input shape for probing feature sizes;
                ``None`` means ``(2, 3, 224, 224)``.

        Returns:
            dict with the following keys:
            - ``"attention"``: the extractor, returning ``(y, [maps])``
              via a forward hook;
            - ``"ft_size"``: per-layer sizes without the batch dimension;
            - ``"output_name"``: the node used as the model output.

        Raises:
            ValueError: If ``layers`` is empty. The message lists the
                available train-mode node names.
        """
        train_names, eval_names = get_graph_node_names(model)
        if not layers:
            # Raise rather than quit(): library code must not terminate the
            # interpreter; the node list is kept in the message for discovery.
            raise ValueError(
                f"No layers were given. Available graph nodes: {train_names}"
            )
        output_name = train_names[-1]
        if output_name != eval_names[-1]:
            warnings.warn(
                "WARNING! THIS MODEL HAS DIFFERENT OUTPUTS FOR TRAIN AND EVAL MODE",
                stacklevel=2,
            )
        attention = create_feature_extractor(
            model, return_nodes=layers + [output_name]
        )
        inp = torch.randn(
            input_dim if input_dim is not None else (2, 3, 224, 224)
        )
        attention.eval()
        # One probe pass to record feature sizes; no autograd graph needed.
        with torch.no_grad():
            outputs = attention(inp)
        outputs.pop(output_name)
        feature_size = [o.shape[1:] for o in outputs.values()]

        def hook(
            module, inputs, outputs
        ) -> tuple[torch.Tensor, list[torch.Tensor]]:
            # Replace the extractor's dict output with (y, feature maps).
            y: torch.Tensor = outputs.pop(output_name)
            return y, list(outputs.values())

        attention.register_forward_hook(hook)
        return {
            "attention": attention,
            "ft_size": feature_size,
            "output_name": output_name,
        }

    @torch.no_grad()
    def get_contributions(self):
        """Return each feature layer's share of the fuser weights per class.

        The fuser weight is reshaped to ``(num_classes, sum(channels))`` and
        softmax-normalized over input channels. It is then split per layer
        and summed, so each row of ``contribs`` sums to 1.

        Returns:
            tuple: ``(layer_names, contribs)``, where ``contribs`` has shape
            ``(num_classes, len(channels))``.
        """
        # flatten(1) instead of squeeze(): squeeze() also drops the class
        # axis when num_classes == 1, which breaks softmax(dim=1).
        weights = self.fuser.weight.flatten(start_dim=1)
        per_layer = weights.softmax(dim=1).split(self.channels, dim=1)
        contribs = torch.stack(
            [chunk.sum(dim=1) for chunk in per_layer], dim=1
        )
        return self.layer_names, contribs


def minmax_compose(fn):
    """Wrap ``fn`` so that its result is passed through ``minmax_4d``."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return minmax_4d(fn(*args, **kwargs))

    return wrapper
15class PTAMEAttention(nn.Module): 16 """PTAMEAttention is a class that represents the attention mechanism based 17 on a feature extractor and P-TAME.""" 18 19 def __init__( 20 self, 21 model: dict[str, list | nn.Module] = { 22 "model": resnet18(weights="DEFAULT"), 23 "layers": ["layer1", "layer2", "layer3", "layer4"], 24 }, 25 input_dim: list[int] | None = [2, 3, 224, 224], 26 scale_up: nn.Module = nn.Upsample(scale_factor=2, mode="bilinear"), 27 activation: nn.Module = nn.ReLU(), 28 fuser_bias: bool = True, 29 num_classes: int = 1000, 30 unfreeze: int = 0, 31 cascading: bool = True, 32 only_layer_train: bool = False, 33 normalizers: list = [torch.sigmoid, minmax_4d], 34 ): 35 super().__init__() 36 self.layer_names = model["layers"] 37 fx_dict = self._make_fx(**model, input_dim=input_dim) 38 self.attention = fx_dict["attention"] 39 self.attention.requires_grad_(False).eval() 40 self.ft_size = fx_dict["ft_size"] 41 self.output_name = fx_dict["output_name"] 42 self.scale_up = scale_up 43 self.num_classes = num_classes 44 45 self.resolution = self.ft_size[0][-1] 46 channels = [ft[0] for ft in self.ft_size] 47 self.channels = channels 48 self.convs = nn.ModuleList( 49 [ 50 nn.Conv2d( 51 in_channels=c, 52 out_channels=c, 53 kernel_size=1, 54 padding=0, 55 bias=True, 56 ) 57 for c in channels 58 ] 59 ) 60 self.bns = nn.ModuleList([nn.BatchNorm2d(c) for c in channels]) 61 self.act = activation 62 self.fuser = nn.Conv2d( 63 in_channels=sum(channels), 64 out_channels=self.num_classes, 65 kernel_size=1, 66 padding=0, 67 bias=fuser_bias, 68 ) 69 70 self.normalizers = normalizers 71 self.trainable_layers = [self.convs, self.bns, self.fuser] 72 73 @staticmethod 74 def list_processor( 75 ops_list: nn.ModuleList, feature_list: list[torch.Tensor] 76 ) -> list[torch.Tensor]: 77 """Apply a list of operations to a list of feature maps. 78 79 Args: 80 ops_list (nn.ModuleList): List of operations to apply. 81 feature_list (list[torch.Tensor]): List of feature maps. 
82 83 Returns: 84 list[torch.Tensor]: List of feature maps with the operations applied. 85 """ 86 return [op(feature) for op, feature in zip(ops_list, feature_list)] 87 88 @staticmethod 89 def skip_connection( 90 a: list[torch.Tensor], b: list[torch.Tensor] 91 ) -> list[torch.Tensor]: 92 """Add the feature maps of two lists. 93 94 Args: 95 a (list[torch.Tensor]): First list of feature maps. 96 b (list[torch.Tensor]): Second list of feature maps. 97 98 Returns: 99 list[torch.Tensor]: List of feature maps with the two input lists added. 100 """ 101 return [a + b for a, b in zip(a, b)] 102 103 def list_activation( 104 self, feature_list: list[torch.Tensor] 105 ) -> list[torch.Tensor]: 106 """Apply the activation function to the feature maps. 107 108 Args: 109 feature_list (list[torch.Tensor]): List of feature maps. 110 111 Returns: 112 list[torch.Tensor]: List of feature maps with the activation function applied. 113 """ 114 return [self.act(feature) for feature in feature_list] 115 116 def upscale(self, feature_list: list[torch.Tensor]) -> list[torch.Tensor]: 117 """Upscale the feature maps to the original resolution. 118 119 Args: 120 feature_list (list[torch.Tensor]): List of feature maps. 121 122 Returns: 123 list[torch.Tensor]: List of upscaled feature maps. 
124 """ 125 126 def up(map, times): 127 return up(self.scale_up(map), times - 1) if times > 0 else map 128 129 return [ 130 up( 131 feature, 132 int(math.log(self.resolution / feature.shape[-1], 2)), 133 ) 134 for feature in feature_list 135 ] 136 137 def forward( 138 self, 139 images: torch.Tensor, 140 with_outputs=False, 141 return_cms=False, 142 **kwargs, 143 ) -> torch.Tensor: 144 """Overwrite of attention forward method to include explanation 145 head.""" 146 y, ft_maps = self.attention(images) 147 148 class_maps = self.list_processor(self.convs, ft_maps) 149 class_maps = self.list_processor(self.bns, class_maps) 150 class_maps = self.skip_connection(class_maps, ft_maps) 151 class_maps = self.list_activation(class_maps) 152 class_maps = self.upscale(class_maps) 153 class_maps = torch.cat(class_maps, 1) 154 if return_cms: 155 return class_maps 156 c = self.fuser(class_maps) 157 if with_outputs: 158 return y, c 159 if self.training or self.normalizers[1] is None: 160 return self.normalizers[0](c) 161 else: 162 return self.normalizers[1](c) 163 164 def _make_fx(self, model, layers, input_dim) -> dict: 165 """Create the feature extractor from the model.""" 166 train_names, eval_names = get_graph_node_names(model) 167 if not layers: 168 print(train_names) 169 quit() 170 output = (train_names[-1], eval_names[-1]) 171 if output[0] != output[1]: 172 print( 173 "WARNING! 
THIS MODEL HAS DIFFERENT OUTPUTS FOR TRAIN AND EVAL MODE" 174 ) 175 output_name = output[0] 176 attention = create_feature_extractor( 177 model, return_nodes=layers + [output_name] 178 ) 179 if input_dim is not None: 180 inp = torch.randn(input_dim) 181 else: 182 inp = torch.randn(2, 3, 224, 224) 183 attention.eval() 184 outputs = attention(inp) 185 outputs.pop(output_name) 186 features = outputs.values() 187 feature_size = [o.shape[1:] for o in features] 188 189 def hook( 190 module, inputs, outputs 191 ) -> tuple[torch.Tensor, list[torch.Tensor]]: 192 y: torch.Tensor = outputs.pop(output_name) 193 features = list(outputs.values()) 194 return y, features 195 196 attention.register_forward_hook(hook) 197 return { 198 "attention": attention, 199 "ft_size": feature_size, 200 "output_name": output_name, 201 } 202 203 @torch.no_grad() 204 def get_contributions(self): 205 """Calculate the contributions of the attention mechanism.""" 206 for name, param in self.fuser.named_parameters(): 207 if "weight" in name: 208 weights = param.squeeze() 209 channels = self.channels.copy() 210 contribs = torch.stack( 211 [ 212 a.sum(dim=1) 213 for a in weights.softmax(dim=1).split(channels, dim=1) 214 ], 215 dim=1, 216 ) 217 return self.layer_names, contribs
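PTAMEAttention represents the P-TAME attention mechanism. It is built from a frozen feature extractor (derived from a backbone model) plus a trainable explanation head that turns the extracted feature maps into per-class attention maps.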
def __init__(
    self,
    model: dict[str, list | nn.Module] | None = None,
    input_dim: list[int] | tuple[int, ...] | None = (2, 3, 224, 224),
    scale_up: nn.Module | None = None,
    activation: nn.Module | None = None,
    fuser_bias: bool = True,
    num_classes: int = 1000,
    unfreeze: int = 0,
    cascading: bool = True,
    only_layer_train: bool = False,
    normalizers: list | None = None,
):
    """Build the frozen feature extractor and the trainable P-TAME head.

    Args:
        model: Dict with keys ``"model"`` (the backbone module) and
            ``"layers"`` (graph node names used as feature maps). Defaults to
            a pretrained ResNet-18 with
            ``["layer1", "layer2", "layer3", "layer4"]``, built per call
            rather than once at import time.
        input_dim: Shape of the dummy input used once to probe the
            feature-map sizes. ``None`` means ``(2, 3, 224, 224)``.
        scale_up: Upsampling module; ``upscale`` applies it ``log2(ratio)``
            times, so it is expected to double the spatial size. Defaults to
            a fresh ``nn.Upsample(scale_factor=2, mode="bilinear")``.
        activation: Activation applied after the residual connection.
            Defaults to a fresh ``nn.ReLU()``.
        fuser_bias: Whether the 1x1 fuser convolution has a bias.
        num_classes: Number of output maps (fuser output channels).
        unfreeze: Not used by this constructor; kept for compatibility.
        cascading: Not used by this constructor; kept for compatibility.
        only_layer_train: Not used by this constructor; kept for
            compatibility.
        normalizers: ``[train_normalizer, eval_normalizer]`` applied to the
            fuser output in ``forward``. Defaults to
            ``[torch.sigmoid, minmax_4d]``.
    """
    super().__init__()
    # Defaults are created here, not in the signature, so nothing is
    # downloaded at import time and no module/list is shared between
    # instances.
    if model is None:
        model = {
            "model": resnet18(weights="DEFAULT"),
            "layers": ["layer1", "layer2", "layer3", "layer4"],
        }
    if scale_up is None:
        scale_up = nn.Upsample(scale_factor=2, mode="bilinear")
    if activation is None:
        activation = nn.ReLU()
    if normalizers is None:
        normalizers = [torch.sigmoid, minmax_4d]

    self.layer_names = model["layers"]
    fx_dict = self._make_fx(**model, input_dim=input_dim)
    self.attention = fx_dict["attention"]
    # Freeze the backbone.
    # NOTE(review): self.attention is a registered submodule, so calling
    # .train() on this module puts it back in train mode (e.g. BatchNorm
    # running stats). Confirm callers re-apply .eval() if that matters.
    self.attention.requires_grad_(False).eval()
    self.ft_size = fx_dict["ft_size"]
    self.output_name = fx_dict["output_name"]
    self.scale_up = scale_up
    self.num_classes = num_classes

    # Target width for upscale(): the first feature map's width.
    # Presumably the highest-resolution layer; TODO confirm.
    self.resolution = self.ft_size[0][-1]
    channels = [ft[0] for ft in self.ft_size]
    self.channels = channels
    self.convs = nn.ModuleList(
        [
            nn.Conv2d(
                in_channels=c,
                out_channels=c,
                kernel_size=1,
                padding=0,
                bias=True,
            )
            for c in channels
        ]
    )
    self.bns = nn.ModuleList([nn.BatchNorm2d(c) for c in channels])
    self.act = activation
    self.fuser = nn.Conv2d(
        in_channels=sum(channels),
        out_channels=self.num_classes,
        kernel_size=1,
        padding=0,
        bias=fuser_bias,
    )

    self.normalizers = normalizers
    self.trainable_layers = [self.convs, self.bns, self.fuser]
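Build the frozen feature extractor from the wrapped model, and the trainable P-TAME head: per-layer 1x1 convolutions and batch norms, plus a 1x1 fuser convolution. (The previous text, "Initialize internal Module state, shared by both nn.Module and ScriptModule.", is the inherited `nn.Module.__init__` docstring and does not describe this constructor.)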
73 @staticmethod 74 def list_processor( 75 ops_list: nn.ModuleList, feature_list: list[torch.Tensor] 76 ) -> list[torch.Tensor]: 77 """Apply a list of operations to a list of feature maps. 78 79 Args: 80 ops_list (nn.ModuleList): List of operations to apply. 81 feature_list (list[torch.Tensor]): List of feature maps. 82 83 Returns: 84 list[torch.Tensor]: List of feature maps with the operations applied. 85 """ 86 return [op(feature) for op, feature in zip(ops_list, feature_list)]
Apply a list of operations to a list of feature maps.
Args: ops_list (nn.ModuleList): List of operations to apply. feature_list (list[torch.Tensor]): List of feature maps.
Returns: list[torch.Tensor]: List of feature maps with the operations applied.
88 @staticmethod 89 def skip_connection( 90 a: list[torch.Tensor], b: list[torch.Tensor] 91 ) -> list[torch.Tensor]: 92 """Add the feature maps of two lists. 93 94 Args: 95 a (list[torch.Tensor]): First list of feature maps. 96 b (list[torch.Tensor]): Second list of feature maps. 97 98 Returns: 99 list[torch.Tensor]: List of feature maps with the two input lists added. 100 """ 101 return [a + b for a, b in zip(a, b)]
Add the feature maps of two lists.
Args: a (list[torch.Tensor]): First list of feature maps. b (list[torch.Tensor]): Second list of feature maps.
Returns: list[torch.Tensor]: List of feature maps with the two input lists added.
def list_activation(
    self, feature_list: list[torch.Tensor]
) -> list[torch.Tensor]:
    """Run ``self.act`` over each feature map, preserving order.

    Args:
        feature_list (list[torch.Tensor]): Feature maps to activate.

    Returns:
        list[torch.Tensor]: One activated map per input map.
    """
    return list(map(self.act, feature_list))
Apply the activation function to the feature maps.
Args: feature_list (list[torch.Tensor]): List of feature maps.
Returns: list[torch.Tensor]: List of feature maps with the activation function applied.
def upscale(self, feature_list: list[torch.Tensor]) -> list[torch.Tensor]:
    """Upscale the feature maps to ``self.resolution``.

    Each map goes through ``self.scale_up`` ``round(log2(ratio))`` times,
    where ``ratio = self.resolution / width``. This assumes ``scale_up``
    doubles the size and the ratios are powers of two. Maps already at or
    above ``self.resolution`` are returned unchanged.

    Args:
        feature_list (list[torch.Tensor]): Feature maps.

    Returns:
        list[torch.Tensor]: Upscaled feature maps.
    """
    upscaled = []
    for feature in feature_list:
        # math.log2 is exact for powers of two, unlike math.log(x, 2);
        # round() guards against truncating e.g. 2.9999... down to 2.
        steps = round(math.log2(self.resolution / feature.shape[-1]))
        for _ in range(steps):
            feature = self.scale_up(feature)
        upscaled.append(feature)
    return upscaled
Upscale the feature maps to the original resolution.
Args: feature_list (list[torch.Tensor]): List of feature maps.
Returns: list[torch.Tensor]: List of upscaled feature maps.
137 def forward( 138 self, 139 images: torch.Tensor, 140 with_outputs=False, 141 return_cms=False, 142 **kwargs, 143 ) -> torch.Tensor: 144 """Overwrite of attention forward method to include explanation 145 head.""" 146 y, ft_maps = self.attention(images) 147 148 class_maps = self.list_processor(self.convs, ft_maps) 149 class_maps = self.list_processor(self.bns, class_maps) 150 class_maps = self.skip_connection(class_maps, ft_maps) 151 class_maps = self.list_activation(class_maps) 152 class_maps = self.upscale(class_maps) 153 class_maps = torch.cat(class_maps, 1) 154 if return_cms: 155 return class_maps 156 c = self.fuser(class_maps) 157 if with_outputs: 158 return y, c 159 if self.training or self.normalizers[1] is None: 160 return self.normalizers[0](c) 161 else: 162 return self.normalizers[1](c)
Override of the attention forward method that adds the explanation head.
@torch.no_grad()
def get_contributions(self):
    """Return each feature layer's share of the fuser weights per class.

    The fuser weight is reshaped to ``(num_classes, sum(channels))`` and
    softmax-normalized over input channels. It is then split per layer and
    summed, so each row of ``contribs`` sums to 1.

    Returns:
        tuple: ``(layer_names, contribs)``, where ``contribs`` has shape
        ``(num_classes, len(channels))``.
    """
    # flatten(1) instead of squeeze(): squeeze() also drops the class axis
    # when num_classes == 1, which breaks softmax(dim=1).
    weights = self.fuser.weight.flatten(start_dim=1)
    per_layer = weights.softmax(dim=1).split(self.channels, dim=1)
    contribs = torch.stack([chunk.sum(dim=1) for chunk in per_layer], dim=1)
    return self.layer_names, contribs
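Calculate how much each feature layer contributes to the attention mechanism. The contributions are derived from the softmax-normalized fuser weights, summed per layer.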