ptame.models.ptame_module
from collections.abc import Callable, Iterator
from functools import partial
from pathlib import Path
from typing import Any

import torch
from lightning import LightningModule
from torch import nn
from torchmetrics import MaxMetric, MeanMetric, Metric
from torchmetrics.classification.accuracy import Accuracy

from ptame.utils.measures import Composer

from .components.loss import Loss


class PTAMELitModule(LightningModule):
    """`LightningModule` for PTAME.

    Wraps a network `net` whose output is a mapping (the step methods read the
    keys ``"logits"``, ``"logits_masked"``, ``"targets"`` and ``"masks"``),
    computes the loss terms through `loss`, tracks accuracy on the masked
    logits and, optionally, feeds explainability measures during validation
    and test.

    The module implements the usual Lightning hooks: `__init__`, `setup`,
    `teardown`, `forward`, `training_step`, `validation_step`, `test_step`,
    `configure_optimizers` and the epoch start/end hooks used for metric
    bookkeeping.

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
        self,
        net: torch.nn.Module,
        loss: Loss,
        optimizer: Callable[
            [Iterator[torch.nn.Parameter]], torch.optim.Optimizer
        ],
        scheduler: dict[str, Any] | None,
        val_measures: dict[str, Metric] | None,
        test_measures: dict[str, Metric] | None,
        feature_contribution: bool = False,
        terminate_on_nan: bool = True,
        compile: bool = False,
        **kwargs,
    ) -> None:
        """Initialize a `PTAMELitModule`.

        :param net: The model to train.
        :param loss: The loss object; its `num_terms` attribute decides how
            many per-term loss averages are tracked. Passing `None` stops
            initialization right after `net` is stored (simple restore).
        :param optimizer: The optimizer factory to use for training.
        :param scheduler: The learning rate scheduler to use for training,
            together with the options (how often it should update etc.).
        :param val_measures: The explainability specific validation measures
            to use.
        :param test_measures: The explainability specific test measures to
            use.
        :param feature_contribution: Whether to log per-layer contributions
            of `net.attention` at the end of the test epoch.
        :param terminate_on_nan: Whether `training_step` raises when the
            returned loss contains NaN.
        :param compile: Whether to compile the model in `setup`.
        :param kwargs: Extra hyperparameters; they are not used here directly
            (`test_step` looks up `save_masks` in `self.hparams`).
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(
            ignore=[
                "net",
                "loss",
                "val_measures",
                "test_measures",
                "optimizer",
                "scheduler",
            ],
            logger=False,
        )

        self.net = net

        # cut off initialization for simple restore
        if loss is None:
            return

        # loss function
        self.criterion = loss

        # optimizer and scheduler
        self.optimizer_cfg = optimizer
        self.scheduler_cfg = scheduler

        # metric objects are initialized in `setup` method
        self.train_acc = None
        self.val_acc = None
        self.test_acc = None
        # for averaging each loss term across batches
        self.train_losses = nn.ModuleList(
            [MeanMetric() for _ in range(self.criterion.num_terms)]
        )
        self.val_losses = nn.ModuleList(
            [MeanMetric() for _ in range(self.criterion.num_terms)]
        )
        self.test_loss = MeanMetric()
        # explainability measures, composed under a common logging prefix
        self.val_measures = (
            Composer(nn.ModuleList(val_measures.values()), prefix="val/")
            if val_measures
            else None
        )
        self.test_measures = (
            Composer(nn.ModuleList(test_measures.values()), prefix="test/")
            if test_measures
            else None
        )

        # for tracking best so far validation accuracy
        self.val_acc_best = MaxMetric()

    def forward(self, x: torch.Tensor) -> dict[str, Any]:
        """Run `x` through the compiled network if enabled, else `self.net`.

        :param x: A tensor of images.
        :return: The network output, returned unchanged. The step methods of
            this module treat it as a mapping: it is unpacked into the loss
            and indexed by "logits", "logits_masked", "targets" and "masks".
        """
        net = self.compiled_net if self.hparams.compile else self.net
        return net(x)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before
        # training starts, so make sure validation metrics don't store
        # results from these checks
        for val_loss in self.val_losses:
            val_loss.reset()
        self.val_acc.reset()
        self.val_acc_best.reset()

    def model_step(
        self, batch: tuple[torch.Tensor, torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        """Perform a single model step on a batch of data.

        :param batch: A batch of data (a 2-tuple) containing the input tensor
            of images and target labels. Only the images are used here.

        :return: A tuple containing (in order):
            - The losses returned by the criterion (callers index it per
              term; index 0 is returned for backpropagation).
            - The network output mapping produced by `forward`.
        """
        x, _ = batch
        out = self.forward(x)
        losses = self.criterion(**out, epoch=self.trainer.current_epoch)
        return losses, out

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training
        set.

        :param batch: A batch of data (a tuple) containing the input tensor of
            images and target labels.
        :param batch_idx: The index of the current batch.
        :return: The first loss term (`losses[0]`), used for backpropagation.
        :raises ValueError: If `terminate_on_nan` is set and the returned loss
            contains NaN.
        """
        losses, out = self.model_step(batch)
        preds_masked = out["logits_masked"]
        targets = out["targets"]
        # update and log metrics
        for i, metric in enumerate(self.train_losses):
            metric(losses[i])  # compute metric
            self.log(f"train/loss[{i}]", metric, prog_bar=True)
        self.train_acc(preds_masked, targets)
        self.log("train/acc", self.train_acc, prog_bar=True)

        if self.hparams.terminate_on_nan and losses[0].isnan().any():
            raise ValueError("NaN detected in loss!")
        # return loss or backpropagation will fail
        return losses[0]

    def on_validation_epoch_start(self) -> None:
        """Lightning hook that is called before a validation epoch begins."""
        if self.val_measures:
            if self.hparams.compile:
                self.val_measures.register_net(self.compiled_net)
            else:
                self.val_measures.register_net(self.net)

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Perform a single validation step on a batch of data from the
        validation set.

        :param batch: A batch of data (a tuple) containing the input tensor of
            images and target labels.
        :param batch_idx: The index of the current batch.
        """
        losses, out = self.model_step(batch)
        preds = out["logits"]
        preds_masked = out["logits_masked"]
        targets = out["targets"]
        maps = out["masks"]

        # update and log metrics
        for i, metric in enumerate(self.val_losses):
            metric.update(losses[i])
            self.log(f"val/loss[{i}]", metric)
        self.val_acc.update(preds_masked, targets)
        self.log("val/acc", self.val_acc)
        if self.val_measures:
            self.val_measures.update(batch[0], preds, targets, maps)

    def on_validation_epoch_end(self) -> None:
        """Lightning hook that is called when a validation epoch ends."""
        acc = self.val_acc.compute()  # get current val acc
        self.val_acc_best(acc)  # update best so far val acc
        # log `val_acc_best` as a value through `.compute()` method, instead
        # of as a metric object otherwise metric would be reset by lightning
        # after each epoch
        self.log(
            "val/acc_best",
            self.val_acc_best.compute(),
            sync_dist=True,
            prog_bar=True,
        )
        if self.val_measures:
            self.log_dict(self.val_measures.compute())
            self.val_measures.reset()

    def on_test_epoch_start(self) -> None:
        """Lightning hook that is called before a test epoch begins."""
        if self.test_measures:
            if self.hparams.compile:
                self.test_measures.register_net(self.compiled_net)
            else:
                self.test_measures.register_net(self.net)

    def test_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data (a tuple) containing the input tensor of
            images and target labels.
        :param batch_idx: The index of the current batch.
        """
        losses, out = self.model_step(batch)
        preds = out["logits"]
        preds_masked = out["logits_masked"]
        targets = out["targets"]
        maps = out["masks"]

        # update and log metrics
        self.test_loss.update(losses[0])
        self.log("test/loss", self.test_loss)
        self.test_acc.update(preds_masked, targets)
        self.log("test/acc", self.test_acc)
        # optionally dump the masks of this batch to `<save_masks>/<batch_idx>.pt`
        if save_masks := self.hparams.get("save_masks"):
            Path(save_masks).mkdir(parents=True, exist_ok=True)
            torch.save(maps, f"{save_masks}/{batch_idx}.pt")
        if self.test_measures:
            self.test_measures.update(batch[0], preds, targets, maps)

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        if self.hparams.feature_contribution:
            layer_names, contribs = self.net.attention.get_contributions()
            # average over the first dimension, then log one value per layer
            contribs_dict = {
                layer: contrib
                for layer, contrib in zip(layer_names, contribs.mean(dim=0))
            }
            self.log_dict(contribs_dict)

        if self.test_measures:
            self.log_dict(self.test_measures.compute())
            self.test_measures.reset()

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train +
        validate), validate, test, or predict.

        Compiles the network when requested and creates the accuracy metrics,
        whose number of classes is read from the trainer's datamodule.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile:
            self.compiled_net = torch.compile(self.net)
        # metric objects for calculating and averaging accuracy across batches
        self.train_acc = Accuracy(
            task="multiclass", num_classes=self.trainer.datamodule.num_classes
        )
        self.val_acc = Accuracy(
            task="multiclass", num_classes=self.trainer.datamodule.num_classes
        )
        self.test_acc = Accuracy(
            task="multiclass", num_classes=self.trainer.datamodule.num_classes
        )

    def teardown(self, stage: str) -> None:
        """Lightning hook that is called at the end of fit (train + validate),
        validate, test, or predict.

        Drops the compiled network created in `setup`.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile:
            del self.compiled_net

    def configure_optimizers(self) -> dict[str, Any]:
        """Build the optimizer and, if configured, the learning-rate scheduler.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the configured optimizers and learning-rate
            schedulers to be used for training.
        """
        optimizer = self.optimizer_cfg(params=self.trainer.model.parameters())
        if (scheduler_info := self.scheduler_cfg) is not None:
            scheduler = scheduler_info["scheduler"]
            if scheduler_info["name"] == "OneCycleLR":
                # OneCycleLR additionally gets the total number of steps,
                # taken from the trainer
                scheduler = partial(
                    scheduler,
                    total_steps=self.trainer.estimated_stepping_batches,
                )
            scheduler = scheduler(optimizer=optimizer)
            scheduler_params = scheduler_info["params"]
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": scheduler_params["interval"],
                    "frequency": scheduler_params["frequency"],
                    "name": scheduler_params["name"],
                },
            }
        return {"optimizer": optimizer}

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Remove unwanted entries from the checkpoint's state dict in place.

        Every key containing "backbone", "_orig_mod" or
        "net.attention.attention" is dropped.

        :param checkpoint: The checkpoint dictionary about to be saved.
        """
        state_dict = checkpoint["state_dict"]
        dropped_fragments = ("backbone", "_orig_mod", "net.attention.attention")
        # Iterate over a snapshot of the keys because entries are removed in
        # the loop. Each key is removed at most once, even when it matches
        # several fragments (popping it twice would raise KeyError).
        for name in list(state_dict):
            if any(fragment in name for fragment in dropped_fragments):
                del state_dict[name]


if __name__ == "__main__":
    _ = PTAMELitModule(None, None, None, None, None, None)
class PTAMELitModule(LightningModule):
    """`LightningModule` for PTAME.

    Wraps a network `net` whose output is a mapping (the step methods read the
    keys ``"logits"``, ``"logits_masked"``, ``"targets"`` and ``"masks"``),
    computes the loss terms through `loss`, tracks accuracy on the masked
    logits and, optionally, feeds explainability measures during validation
    and test.

    The module implements the usual Lightning hooks: `__init__`, `setup`,
    `teardown`, `forward`, `training_step`, `validation_step`, `test_step`,
    `configure_optimizers` and the epoch start/end hooks used for metric
    bookkeeping.

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
        self,
        net: torch.nn.Module,
        loss: Loss,
        optimizer: Callable[
            [Iterator[torch.nn.Parameter]], torch.optim.Optimizer
        ],
        scheduler: dict[str, Any] | None,
        val_measures: dict[str, Metric] | None,
        test_measures: dict[str, Metric] | None,
        feature_contribution: bool = False,
        terminate_on_nan: bool = True,
        compile: bool = False,
        **kwargs,
    ) -> None:
        """Initialize a `PTAMELitModule`.

        :param net: The model to train.
        :param loss: The loss object; its `num_terms` attribute decides how
            many per-term loss averages are tracked. Passing `None` stops
            initialization right after `net` is stored (simple restore).
        :param optimizer: The optimizer factory to use for training.
        :param scheduler: The learning rate scheduler to use for training,
            together with the options (how often it should update etc.).
        :param val_measures: The explainability specific validation measures
            to use.
        :param test_measures: The explainability specific test measures to
            use.
        :param feature_contribution: Whether to log per-layer contributions
            of `net.attention` at the end of the test epoch.
        :param terminate_on_nan: Whether `training_step` raises when the
            returned loss contains NaN.
        :param compile: Whether to compile the model in `setup`.
        :param kwargs: Extra hyperparameters; they are not used here directly
            (`test_step` looks up `save_masks` in `self.hparams`).
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(
            ignore=[
                "net",
                "loss",
                "val_measures",
                "test_measures",
                "optimizer",
                "scheduler",
            ],
            logger=False,
        )

        self.net = net

        # cut off initialization for simple restore
        if loss is None:
            return

        # loss function
        self.criterion = loss

        # optimizer and scheduler
        self.optimizer_cfg = optimizer
        self.scheduler_cfg = scheduler

        # metric objects are initialized in `setup` method
        self.train_acc = None
        self.val_acc = None
        self.test_acc = None
        # for averaging each loss term across batches
        self.train_losses = nn.ModuleList(
            [MeanMetric() for _ in range(self.criterion.num_terms)]
        )
        self.val_losses = nn.ModuleList(
            [MeanMetric() for _ in range(self.criterion.num_terms)]
        )
        self.test_loss = MeanMetric()
        # explainability measures, composed under a common logging prefix
        self.val_measures = (
            Composer(nn.ModuleList(val_measures.values()), prefix="val/")
            if val_measures
            else None
        )
        self.test_measures = (
            Composer(nn.ModuleList(test_measures.values()), prefix="test/")
            if test_measures
            else None
        )

        # for tracking best so far validation accuracy
        self.val_acc_best = MaxMetric()

    def forward(self, x: torch.Tensor) -> dict[str, Any]:
        """Run `x` through the compiled network if enabled, else `self.net`.

        :param x: A tensor of images.
        :return: The network output, returned unchanged. The step methods of
            this module treat it as a mapping: it is unpacked into the loss
            and indexed by "logits", "logits_masked", "targets" and "masks".
        """
        net = self.compiled_net if self.hparams.compile else self.net
        return net(x)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before
        # training starts, so make sure validation metrics don't store
        # results from these checks
        for val_loss in self.val_losses:
            val_loss.reset()
        self.val_acc.reset()
        self.val_acc_best.reset()

    def model_step(
        self, batch: tuple[torch.Tensor, torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, Any]]:
        """Perform a single model step on a batch of data.

        :param batch: A batch of data (a 2-tuple) containing the input tensor
            of images and target labels. Only the images are used here.

        :return: A tuple containing (in order):
            - The losses returned by the criterion (callers index it per
              term; index 0 is returned for backpropagation).
            - The network output mapping produced by `forward`.
        """
        x, _ = batch
        out = self.forward(x)
        losses = self.criterion(**out, epoch=self.trainer.current_epoch)
        return losses, out

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training
        set.

        :param batch: A batch of data (a tuple) containing the input tensor of
            images and target labels.
        :param batch_idx: The index of the current batch.
        :return: The first loss term (`losses[0]`), used for backpropagation.
        :raises ValueError: If `terminate_on_nan` is set and the returned loss
            contains NaN.
        """
        losses, out = self.model_step(batch)
        preds_masked = out["logits_masked"]
        targets = out["targets"]
        # update and log metrics
        for i, metric in enumerate(self.train_losses):
            metric(losses[i])  # compute metric
            self.log(f"train/loss[{i}]", metric, prog_bar=True)
        self.train_acc(preds_masked, targets)
        self.log("train/acc", self.train_acc, prog_bar=True)

        if self.hparams.terminate_on_nan and losses[0].isnan().any():
            raise ValueError("NaN detected in loss!")
        # return loss or backpropagation will fail
        return losses[0]

    def on_validation_epoch_start(self) -> None:
        """Lightning hook that is called before a validation epoch begins."""
        if self.val_measures:
            if self.hparams.compile:
                self.val_measures.register_net(self.compiled_net)
            else:
                self.val_measures.register_net(self.net)

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Perform a single validation step on a batch of data from the
        validation set.

        :param batch: A batch of data (a tuple) containing the input tensor of
            images and target labels.
        :param batch_idx: The index of the current batch.
        """
        losses, out = self.model_step(batch)
        preds = out["logits"]
        preds_masked = out["logits_masked"]
        targets = out["targets"]
        maps = out["masks"]

        # update and log metrics
        for i, metric in enumerate(self.val_losses):
            metric.update(losses[i])
            self.log(f"val/loss[{i}]", metric)
        self.val_acc.update(preds_masked, targets)
        self.log("val/acc", self.val_acc)
        if self.val_measures:
            self.val_measures.update(batch[0], preds, targets, maps)

    def on_validation_epoch_end(self) -> None:
        """Lightning hook that is called when a validation epoch ends."""
        acc = self.val_acc.compute()  # get current val acc
        self.val_acc_best(acc)  # update best so far val acc
        # log `val_acc_best` as a value through `.compute()` method, instead
        # of as a metric object otherwise metric would be reset by lightning
        # after each epoch
        self.log(
            "val/acc_best",
            self.val_acc_best.compute(),
            sync_dist=True,
            prog_bar=True,
        )
        if self.val_measures:
            self.log_dict(self.val_measures.compute())
            self.val_measures.reset()

    def on_test_epoch_start(self) -> None:
        """Lightning hook that is called before a test epoch begins."""
        if self.test_measures:
            if self.hparams.compile:
                self.test_measures.register_net(self.compiled_net)
            else:
                self.test_measures.register_net(self.net)

    def test_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data (a tuple) containing the input tensor of
            images and target labels.
        :param batch_idx: The index of the current batch.
        """
        losses, out = self.model_step(batch)
        preds = out["logits"]
        preds_masked = out["logits_masked"]
        targets = out["targets"]
        maps = out["masks"]

        # update and log metrics
        self.test_loss.update(losses[0])
        self.log("test/loss", self.test_loss)
        self.test_acc.update(preds_masked, targets)
        self.log("test/acc", self.test_acc)
        # optionally dump the masks of this batch to `<save_masks>/<batch_idx>.pt`
        if save_masks := self.hparams.get("save_masks"):
            Path(save_masks).mkdir(parents=True, exist_ok=True)
            torch.save(maps, f"{save_masks}/{batch_idx}.pt")
        if self.test_measures:
            self.test_measures.update(batch[0], preds, targets, maps)

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        if self.hparams.feature_contribution:
            layer_names, contribs = self.net.attention.get_contributions()
            # average over the first dimension, then log one value per layer
            contribs_dict = {
                layer: contrib
                for layer, contrib in zip(layer_names, contribs.mean(dim=0))
            }
            self.log_dict(contribs_dict)

        if self.test_measures:
            self.log_dict(self.test_measures.compute())
            self.test_measures.reset()

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train +
        validate), validate, test, or predict.

        Compiles the network when requested and creates the accuracy metrics,
        whose number of classes is read from the trainer's datamodule.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile:
            self.compiled_net = torch.compile(self.net)
        # metric objects for calculating and averaging accuracy across batches
        self.train_acc = Accuracy(
            task="multiclass", num_classes=self.trainer.datamodule.num_classes
        )
        self.val_acc = Accuracy(
            task="multiclass", num_classes=self.trainer.datamodule.num_classes
        )
        self.test_acc = Accuracy(
            task="multiclass", num_classes=self.trainer.datamodule.num_classes
        )

    def teardown(self, stage: str) -> None:
        """Lightning hook that is called at the end of fit (train + validate),
        validate, test, or predict.

        Drops the compiled network created in `setup`.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile:
            del self.compiled_net

    def configure_optimizers(self) -> dict[str, Any]:
        """Build the optimizer and, if configured, the learning-rate scheduler.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the configured optimizers and learning-rate
            schedulers to be used for training.
        """
        optimizer = self.optimizer_cfg(params=self.trainer.model.parameters())
        if (scheduler_info := self.scheduler_cfg) is not None:
            scheduler = scheduler_info["scheduler"]
            if scheduler_info["name"] == "OneCycleLR":
                # OneCycleLR additionally gets the total number of steps,
                # taken from the trainer
                scheduler = partial(
                    scheduler,
                    total_steps=self.trainer.estimated_stepping_batches,
                )
            scheduler = scheduler(optimizer=optimizer)
            scheduler_params = scheduler_info["params"]
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": scheduler_params["interval"],
                    "frequency": scheduler_params["frequency"],
                    "name": scheduler_params["name"],
                },
            }
        return {"optimizer": optimizer}

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Remove unwanted entries from the checkpoint's state dict in place.

        Every key containing "backbone", "_orig_mod" or
        "net.attention.attention" is dropped.

        :param checkpoint: The checkpoint dictionary about to be saved.
        """
        state_dict = checkpoint["state_dict"]
        dropped_fragments = ("backbone", "_orig_mod", "net.attention.attention")
        # Iterate over a snapshot of the keys because entries are removed in
        # the loop. Each key is removed at most once, even when it matches
        # several fragments (popping it twice would raise KeyError).
        for name in list(state_dict):
            if any(fragment in name for fragment in dropped_fragments):
                del state_dict[name]
LightningModule for PTAME.
A LightningModule typically implements the following key methods:
def __init__(self):
# Define initialization code here.
def setup(self, stage):
# Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
# This hook is called on every process when using DDP.
def training_step(self, batch, batch_idx):
# The complete training step.
def validation_step(self, batch, batch_idx):
# The complete validation step.
def test_step(self, batch, batch_idx):
# The complete test step.
def predict_step(self, batch, batch_idx):
# The complete predict step.
def configure_optimizers(self):
# Define and configure optimizers and LR schedulers.
Docs: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
def __init__(
    self,
    net: torch.nn.Module,
    loss: Loss,
    optimizer: Callable[
        [Iterator[torch.nn.Parameter]], torch.optim.Optimizer
    ],
    scheduler: dict[str, Any] | None,
    val_measures: dict[str, Metric] | None,
    test_measures: dict[str, Metric] | None,
    feature_contribution: bool = False,
    terminate_on_nan: bool = True,
    compile: bool = False,
    **kwargs,
) -> None:
    """Set up the network, loss, optimizer configuration and metrics.

    :param net: The model to train.
    :param loss: The loss object; its `num_terms` attribute decides how many
        per-term loss averages are tracked. Passing `None` stops
        initialization right after `net` is stored (simple restore).
    :param optimizer: The optimizer factory to use for training.
    :param scheduler: The learning rate scheduler to use for training,
        together with the options (how often it should update etc.).
    :param val_measures: The explainability specific validation measures to
        use.
    :param test_measures: The explainability specific test measures to use.
    :param feature_contribution: Stored as a hyperparameter.
    :param terminate_on_nan: Stored as a hyperparameter.
    :param compile: Whether to compile the model before training.
    :param kwargs: Additional keyword arguments; not used directly here.
    """
    super().__init__()

    # Makes the init params reachable through `self.hparams` and stores them
    # in checkpoints; the heavyweight objects are excluded.
    self.save_hyperparameters(
        ignore=[
            "net",
            "loss",
            "val_measures",
            "test_measures",
            "optimizer",
            "scheduler",
        ],
        logger=False,
    )

    self.net = net

    # Simple restore: without a loss nothing else is initialized.
    if loss is None:
        return

    self.criterion = loss
    self.optimizer_cfg = optimizer
    self.scheduler_cfg = scheduler

    # The accuracy metrics are created later, in `setup`.
    self.train_acc = self.val_acc = self.test_acc = None

    def mean_per_term() -> nn.ModuleList:
        # One running mean per loss term.
        return nn.ModuleList(
            [MeanMetric() for _ in range(self.criterion.num_terms)]
        )

    self.train_losses = mean_per_term()
    self.val_losses = mean_per_term()
    self.test_loss = MeanMetric()

    # Explainability measures, composed under a split-specific prefix.
    if val_measures:
        self.val_measures = Composer(
            nn.ModuleList(val_measures.values()), prefix="val/"
        )
    else:
        self.val_measures = None
    if test_measures:
        self.test_measures = Composer(
            nn.ModuleList(test_measures.values()), prefix="test/"
        )
    else:
        self.test_measures = None

    # Best validation accuracy seen so far.
    self.val_acc_best = MaxMetric()
Initialize a PTAMELitModule.
Parameters
- net: The model to train.
- optimizer: The optimizer to use for training.
- scheduler: The learning rate scheduler to use for training, together with the options (how often it should update etc.).
- val_measures: The explainability specific validation measures to use.
- test_measures: The explainability specific test measures to use.
- compile: Whether to compile the model before training.
def forward(self, x: torch.Tensor) -> dict[str, Any]:
    """Run `x` through the compiled network if enabled, else `self.net`.

    :param x: A tensor of images.
    :return: The network output, returned unchanged. The step methods of this
        module treat it as a mapping: it is unpacked into the loss and
        indexed by "logits", "logits_masked", "targets" and "masks".
    """
    net = self.compiled_net if self.hparams.compile else self.net
    return net(x)
Perform a forward pass through the model self.net.
Parameters
- x: A tensor of images.
Returns
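The output of the underlying network, returned unchanged; the step methods treat it as a mapping with the keys "logits", "logits_masked", "targets" and "masks".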
def on_train_start(self) -> None:
    """Lightning hook that is called when training begins.

    Resets the validation loss averages, the validation accuracy and the
    best-accuracy tracker.
    """
    # by default lightning executes validation step sanity checks before
    # training starts, so make sure validation metrics don't store results
    # from these checks
    for val_loss in self.val_losses:
        val_loss.reset()
    self.val_acc.reset()
    self.val_acc_best.reset()
Lightning hook that is called when training begins.
def model_step(
    self, batch: tuple[torch.Tensor, torch.Tensor]
) -> tuple[torch.Tensor, dict[str, Any]]:
    """Perform a single model step on a batch of data.

    :param batch: A batch of data (a 2-tuple) containing the input tensor of
        images and target labels. Only the images are used here.

    :return: A tuple containing (in order):
        - The losses returned by the criterion.
        - The network output mapping produced by `forward`.
    """
    x, _ = batch
    out = self.forward(x)
    # the network output is passed to the loss as keyword arguments
    losses = self.criterion(**out, epoch=self.trainer.current_epoch)
    return losses, out
Perform a single model step on a batch of data.
Parameters
- batch: A batch of data (a tuple) containing the input tensor of images and target labels.
Returns
A tuple containing (in order): - The losses returned by the criterion. - The network output mapping produced by the forward pass.
def training_step(
    self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> torch.Tensor:
    """Run one optimization step's forward pass and metric bookkeeping.

    :param batch: A batch of data (a tuple) containing the input tensor of
        images and target labels.
    :param batch_idx: The index of the current batch.
    :return: The first loss term, which is what gets backpropagated.
    :raises ValueError: If `terminate_on_nan` is set and that loss has NaNs.
    """
    losses, out = self.model_step(batch)
    masked_logits, labels = out["logits_masked"], out["targets"]

    # one running mean per loss term, shown in the progress bar
    for term, running_mean in enumerate(self.train_losses):
        running_mean(losses[term])
        self.log(f"train/loss[{term}]", running_mean, prog_bar=True)

    self.train_acc(masked_logits, labels)
    self.log("train/acc", self.train_acc, prog_bar=True)

    main_loss = losses[0]
    if self.hparams.terminate_on_nan and main_loss.isnan().any():
        raise ValueError("NaN detected in loss!")
    # the loss must be returned, otherwise backpropagation fails
    return main_loss
Perform a single training step on a batch of data from the training set.
Parameters
- batch: A batch of data (a tuple) containing the input tensor of images and target labels.
- batch_idx: The index of the current batch.
Returns
A tensor of losses between model predictions and targets.
def on_validation_epoch_start(self) -> None:
    """Hand the active network to the validation measures, if there are any.

    The compiled network is registered when compilation is enabled, the plain
    one otherwise.
    """
    if not self.val_measures:
        return
    active_net = self.compiled_net if self.hparams.compile else self.net
    self.val_measures.register_net(active_net)
Lightning hook that is called before a validation epoch begins.
def validation_step(
    self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> None:
    """Evaluate one validation batch and update the validation metrics.

    :param batch: A batch of data (a tuple) containing the input tensor of
        images and target labels.
    :param batch_idx: The index of the current batch.
    """
    losses, out = self.model_step(batch)
    logits, masked_logits, labels, masks = (
        out[key] for key in ("logits", "logits_masked", "targets", "masks")
    )

    # per-term loss averages
    for term, running_mean in enumerate(self.val_losses):
        running_mean.update(losses[term])
        self.log(f"val/loss[{term}]", running_mean)

    # accuracy is measured on the masked logits
    self.val_acc.update(masked_logits, labels)
    self.log("val/acc", self.val_acc)

    # the measures receive the input images together with the unmasked
    # logits, the targets and the masks
    if self.val_measures:
        self.val_measures.update(batch[0], logits, labels, masks)
Perform a single validation step on a batch of data from the validation set.
Parameters
- batch: A batch of data (a tuple) containing the input tensor of images and target labels.
- batch_idx: The index of the current batch.
def on_validation_epoch_end(self) -> None:
    """Lightning hook that is called when a validation epoch ends.

    Updates and logs the best validation accuracy so far, then logs and
    resets the validation measures when they are configured.
    """
    # feed this epoch's accuracy into the running maximum
    self.val_acc_best(self.val_acc.compute())
    # The best accuracy is logged as a plain value obtained via `.compute()`
    # rather than as a metric object, because lightning would otherwise reset
    # the metric after each epoch.
    best_so_far = self.val_acc_best.compute()
    self.log("val/acc_best", best_so_far, sync_dist=True, prog_bar=True)

    if not self.val_measures:
        return
    self.log_dict(self.val_measures.compute())
    self.val_measures.reset()
Lightning hook that is called when a validation epoch ends.
def on_test_epoch_start(self) -> None:
    """Lightning hook that is called before a test epoch begins.

    When test measures are configured, registers the network they should
    use: `self.compiled_net` if `hparams.compile` is set, otherwise
    `self.net`.
    """
    if not self.test_measures:
        return
    active_net = self.compiled_net if self.hparams.compile else self.net
    self.test_measures.register_net(active_net)
Lightning hook that is called before a test epoch begins.
251 def test_step( 252 self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int 253 ) -> None: 254 """Perform a single test step on a batch of data from the test set. 255 256 :param batch: A batch of data (a tuple) containing the input tensor of 257 images and target labels. 258 :param batch_idx: The index of the current batch. 259 """ 260 losses, out = self.model_step(batch) 261 preds = out["logits"] 262 preds_masked = out["logits_masked"] 263 targets = out["targets"] 264 maps = out["masks"] 265 266 # update and log metrics 267 self.test_loss.update(losses[0]) 268 self.log("test/loss", self.test_loss) 269 self.test_acc.update(preds_masked, targets) 270 self.log("test/acc", self.test_acc) 271 if save_masks := self.hparams.get("save_masks"): 272 Path(save_masks).mkdir(parents=True, exist_ok=True) 273 torch.save(maps, f"{save_masks}/{batch_idx}.pt") 274 if self.test_measures: 275 self.test_measures.update(batch[0], preds, targets, maps)
Perform a single test step on a batch of data from the test set.
Parameters
- batch: A batch of data (a tuple) containing the input tensor of images and target labels.
- batch_idx: The index of the current batch.
def on_test_epoch_end(self) -> None:
    """Lightning hook that is called when a test epoch ends.

    With `hparams.feature_contribution` enabled, logs one value per layer
    name: the contributions returned by `self.net.attention`, averaged over
    dimension 0. Afterwards logs and resets the test measures when they are
    configured.
    """
    if self.hparams.feature_contribution:
        layer_names, contribs = self.net.attention.get_contributions()
        # pair every layer name with its entry of the dim-0 mean
        self.log_dict(dict(zip(layer_names, contribs.mean(dim=0))))

    if self.test_measures:
        self.log_dict(self.test_measures.compute())
        self.test_measures.reset()
Lightning hook that is called when a test epoch ends.
def setup(self, stage: str) -> None:
    """Lightning hook that is called at the beginning of fit (train +
    validate), validate, test, or predict.

    Compiles `self.net` into `self.compiled_net` when `hparams.compile` is
    set, and creates the train/val/test accuracy metrics using the number of
    classes reported by the trainer's datamodule.

    :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
    """
    if self.hparams.compile:
        self.compiled_net = torch.compile(self.net)

    # the three accuracy trackers are built with identical arguments
    acc_kwargs = {
        "task": "multiclass",
        "num_classes": self.trainer.datamodule.num_classes,
    }
    self.train_acc = Accuracy(**acc_kwargs)
    self.val_acc = Accuracy(**acc_kwargs)
    self.test_acc = Accuracy(**acc_kwargs)
Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.
This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
Parameters
- stage: Either "fit", "validate", "test", or "predict".
315 def teardown(self, stage: str) -> None: 316 """Lightning hook that is called at the end of fit (train + validate), 317 validate, test, or predict. 318 319 This is a good hook when you need to clean something up after the run. 320 321 :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 322 """ 323 if self.hparams.compile: 324 del self.compiled_net
Lightning hook that is called at the end of fit (train + validate), validate, test, or predict.
This is a good hook when you need to clean something up after the run.
Parameters
- stage: Either "fit", "validate", "test", or "predict".
def configure_optimizers(self) -> dict[str, Any]:
    """Build the optimizer and, when configured, the learning-rate scheduler.

    The optimizer factory is called with the parameters of
    `self.trainer.model`. If a scheduler config is present, its
    `"scheduler"` factory is called with the optimizer; a config whose
    `"name"` is `"OneCycleLR"` additionally receives `total_steps` taken
    from the trainer's estimated stepping batches.

    Examples:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

    :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
    """
    optimizer = self.optimizer_cfg(params=self.trainer.model.parameters())

    cfg = self.scheduler_cfg
    if cfg is None:
        return {"optimizer": optimizer}

    factory = cfg["scheduler"]
    extra_kwargs = {}
    if cfg["name"] == "OneCycleLR":
        # only this scheduler is handed the overall step count
        extra_kwargs["total_steps"] = self.trainer.estimated_stepping_batches
    scheduler = factory(optimizer=optimizer, **extra_kwargs)

    options = cfg["params"]
    lr_scheduler = {
        "scheduler": scheduler,
        "interval": options["interval"],
        "frequency": options["frequency"],
        "name": options["name"],
    }
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple.
Examples: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
Returns
A dict containing the configured optimizers and learning-rate schedulers to be used for training.
def on_save_checkpoint(self, checkpoint) -> None:
    """Strip unwanted entries from the checkpoint's state dict before saving.

    Every key of `checkpoint["state_dict"]` that contains `"backbone"`,
    `"_orig_mod"` or `"net.attention.attention"` is removed. The state dict
    is modified in place.

    :param checkpoint: The checkpoint dictionary about to be written; only
        its `"state_dict"` entry is touched.
    """
    excluded_markers = ("backbone", "_orig_mod", "net.attention.attention")
    state = checkpoint["state_dict"]

    # Collect the keys first: the dict must not change size while iterated.
    doomed = [
        key
        for key in state
        if any(marker in key for marker in excluded_markers)
    ]
    for key in doomed:
        # Each key is deleted exactly once, even when it matches several
        # markers; popping it once per matching marker raised KeyError.
        del state[key]
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
Args: checkpoint: The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
Example::
def on_save_checkpoint(self, checkpoint):
# 99% of use cases you don't need to implement this method
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note: Lightning saves all aspects of training (epoch, global step, etc...) including amp scaling. There is no need for you to store anything about training.