ptame.models.cv_module
1from collections.abc import Callable, Iterator 2from functools import partial 3from typing import Any 4 5import torch 6from lightning import LightningModule 7from torch import Tensor 8from torchmetrics import MaxMetric, MeanMetric 9from torchmetrics.classification.accuracy import Accuracy 10 11 12class CVModule(LightningModule): 13 """`LightningModule` for PAMELA. 14 15 A `LightningModule` implements 8 key methods: 16 17 ```python 18 def __init__(self): 19 # Define initialization code here. 20 21 def setup(self, stage): 22 # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'. 23 # This hook is called on every process when using DDP. 24 25 def training_step(self, batch, batch_idx): 26 # The complete training step. 27 28 def validation_step(self, batch, batch_idx): 29 # The complete validation step. 30 31 def test_step(self, batch, batch_idx): 32 # The complete test step. 33 34 def predict_step(self, batch, batch_idx): 35 # The complete predict step. 36 37 def configure_optimizers(self): 38 # Define and configure optimizers and LR schedulers. 39 ``` 40 41 Docs: 42 https://lightning.ai/docs/pytorch/latest/common/lightning_module.html 43 """ 44 45 def __init__( 46 self, 47 net: torch.nn.Module, 48 loss: Callable[[Tensor], Tensor], 49 optimizer: Callable[ 50 [Iterator[torch.nn.Parameter]], torch.optim.Optimizer 51 ], 52 scheduler: dict[str, Any] | None, 53 teacher: Callable[[Tensor], Tensor] | None = None, 54 terminate_on_nan: bool = True, 55 compile: bool = False, 56 **kwargs, 57 ) -> None: 58 """Initialize a `PAMELALitModule`. 59 60 :param net: The model to train. 61 :param optimizer: The optimizer to use for training. 62 :param scheduler: The learning rate scheduler to use for training, 63 together with the options (how often it should update etc.). 64 :param val_measures: The explainability specific validation measures to 65 use. 66 :param test_measures: " test ". 67 :param compile: Whether to compile the model before training. 
68 """ 69 super().__init__() 70 71 # this line allows to access init params with 'self.hparams' attribute 72 # also ensures init params will be stored in ckpt 73 self.save_hyperparameters( 74 ignore=["net", "loss", "optimizer", "scheduler", "teacher"], 75 logger=False, 76 ) 77 78 self.net = net 79 self.teacher = None 80 if teacher is not None: 81 self.teacher = teacher 82 loss = self.kl_loss 83 self.teacher.eval() 84 self.teacher.requires_grad_(False) 85 86 # cut off initialization for simple restore 87 if loss is None: 88 return 89 90 # loss function 91 self.criterion = loss 92 93 # optimizer and scheduler 94 self.optimizer_cfg = optimizer 95 self.scheduler_cfg = scheduler 96 97 # metric objects are initialized in `setup` method 98 self.train_acc = None 99 self.val_acc = None 100 self.test_acc = None 101 # for averaging loss across batches 102 self.train_loss = MeanMetric() 103 self.val_loss = MeanMetric() 104 self.test_loss = MeanMetric() 105 # metric object for calculating and averaging ADIC or ROAD or across batches 106 107 # for tracking best so far validation accuracy 108 self.val_acc_best = MaxMetric() 109 110 @staticmethod 111 def mse_loss(student_logits, teacher_logits): 112 return torch.nn.functional.mse_loss(student_logits, teacher_logits) 113 114 @staticmethod 115 def kl_loss(student_logits, teacher_logits, temperature=5.0): 116 p_t = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1) 117 p_s = torch.nn.functional.softmax(student_logits / temperature, dim=-1) 118 return torch.nn.functional.kl_div( 119 p_s.log(), p_t, reduction="batchmean" 120 ) 121 122 def convert_fc(self) -> None: 123 """Convert the last layer of the model to a fully connected layer.""" 124 if hasattr(self.net, "fc"): 125 in_features = self.net.fc.in_features 126 out_features = ( 127 self.trainer.datamodule.num_classes 128 if self.teacher is None 129 else self.teacher.fc.out_features 130 ) 131 self.classes = out_features 132 if self.net.fc.out_features == out_features: 
133 return 134 self.net.fc = torch.nn.Linear(in_features, out_features) 135 # initialize weights 136 torch.nn.init.xavier_uniform_(self.net.fc.weight) 137 if self.net.fc.bias is not None: 138 torch.nn.init.zeros_(self.net.fc.bias) 139 140 def forward(self, x: torch.Tensor) -> torch.Tensor: 141 """Perform a forward pass through the model `self.net`. 142 143 :param x: A tensor of images. 144 :return: A tensor of logits. 145 """ 146 if self.hparams.compile: 147 return self.compiled_net(x) 148 return self.net(x) 149 150 def on_train_start(self) -> None: 151 """Lightning hook that is called when training begins.""" 152 # by default lightning executes validation step sanity checks before training starts, 153 # so it's worth to make sure validation metrics don't store results from these checks 154 self.val_loss.reset() 155 self.val_acc.reset() 156 self.val_acc_best.reset() 157 158 def model_step( 159 self, batch: tuple[torch.Tensor, torch.Tensor] 160 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 161 """Perform a single model step on a batch of data. 162 163 :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. 164 165 :return: A tuple containing (in order): 166 - A tensor of losses. 167 - A tensor of predictions. 168 - A tensor of target labels. 169 """ 170 x, y = batch 171 target = y 172 if self.teacher is not None: 173 if self.hparams.compile: 174 y = self.compiled_teacher(x) 175 else: 176 y = self.teacher(x) 177 target = y.argmax(dim=1) 178 out = self.forward(x) 179 loss = self.criterion(out, y) 180 return loss, out, target 181 182 def training_step( 183 self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int 184 ) -> torch.Tensor: 185 """Perform a single training step on a batch of data from the training 186 set. 187 188 :param batch: A batch of data (a tuple) containing the input tensor of 189 images and target labels. 190 :param batch_idx: The index of the current batch. 
191 :return: A tensor of losses between model predictions and targets. 192 """ 193 loss, out, target = self.model_step(batch) 194 preds = out.argmax(dim=1) 195 # update and log metrics 196 self.train_loss(loss) # compute metric 197 self.log("train/loss", self.train_loss, prog_bar=True) 198 self.train_acc(preds, target) 199 self.log("train/acc", self.train_acc, prog_bar=True) 200 201 # return loss or backpropagation will fail 202 if self.hparams.terminate_on_nan and loss.isnan().any(): 203 raise ValueError("NaN detected in loss!") 204 return loss 205 206 def validation_step( 207 self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int 208 ) -> None: 209 """Perform a single validation step on a batch of data from the 210 validation set. 211 212 :param batch: A batch of data (a tuple) containing the input tensor of 213 images and target labels. 214 :param batch_idx: The index of the current batch. 215 """ 216 loss, out, target = self.model_step(batch) 217 preds = out.argmax(dim=1) 218 219 # update and log metrics 220 self.val_loss(loss) 221 self.log("val/loss", self.val_loss) 222 self.val_acc(preds, target) 223 self.log("val/acc", self.val_acc) 224 225 def on_validation_epoch_end(self) -> None: 226 "Lightning hook that is called when a validation epoch ends." 227 acc = self.val_acc.compute() # get current val acc 228 self.val_acc_best(acc) # update best so far val acc 229 # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object 230 # otherwise metric would be reset by lightning after each epoch 231 self.log( 232 "val/acc_best", 233 self.val_acc_best.compute(), 234 sync_dist=True, 235 prog_bar=True, 236 ) 237 238 def test_step( 239 self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int 240 ) -> None: 241 """Perform a single test step on a batch of data from the test set. 242 243 :param batch: A batch of data (a tuple) containing the input tensor of 244 images and target labels. 
245 :param batch_idx: The index of the current batch. 246 """ 247 loss, out, target = self.model_step(batch) 248 preds = out.argmax(dim=1) 249 250 # update and log metrics 251 self.test_loss(loss) 252 self.log("test/loss", self.test_loss) 253 self.test_acc(preds, target) 254 self.log("test/acc", self.test_acc) 255 256 def setup(self, stage: str) -> None: 257 """Lightning hook that is called at the beginning of fit (train + 258 validate), validate, test, or predict. 259 260 This is a good hook when you need to build models dynamically or adjust 261 something about them. This hook is called on every process when using 262 DDP. 263 264 :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 265 """ 266 self.convert_fc() 267 if self.hparams.compile: 268 self.compiled_net = torch.compile(self.net) 269 if self.teacher is not None: 270 self.compiled_teacher = torch.compile(self.teacher) 271 # metric objects for calculating and averaging accuracy across batches 272 self.train_acc = Accuracy(task="multiclass", num_classes=self.classes) 273 self.val_acc = Accuracy(task="multiclass", num_classes=self.classes) 274 self.test_acc = Accuracy(task="multiclass", num_classes=self.classes) 275 276 def teardown(self, stage: str) -> None: 277 """Lightning hook that is called at the end of fit (train + validate), 278 validate, test, or predict. 279 280 This is a good hook when you need to clean something up after the run. 281 282 :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 283 """ 284 if self.hparams.compile: 285 del self.compiled_net 286 if self.teacher: 287 del self.compiled_teacher 288 289 def configure_optimizers(self) -> dict[str, Any]: 290 """Choose what optimizers and learning-rate schedulers to use in your 291 optimization. Normally you'd need one. But in the case of GANs or 292 similar you might have multiple. 
293 294 Examples: 295 https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers 296 297 :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. 298 """ 299 optimizer = self.optimizer_cfg(params=self.trainer.model.parameters()) 300 if (scheduler_info := self.scheduler_cfg) is not None: 301 scheduler = scheduler_info["scheduler"] 302 if scheduler_info["name"] == "OneCycleLR": 303 scheduler = partial( 304 scheduler, 305 total_steps=self.trainer.estimated_stepping_batches, 306 ) 307 scheduler = scheduler(optimizer=optimizer) 308 scheduler_params = scheduler_info["params"] 309 return { 310 "optimizer": optimizer, 311 "lr_scheduler": { 312 "scheduler": scheduler, 313 "interval": scheduler_params["interval"], 314 "frequency": scheduler_params["frequency"], 315 "name": scheduler_params["name"], 316 }, 317 } 318 return {"optimizer": optimizer} 319 320 def on_save_checkpoint(self, checkpoint): 321 # pop the keys you are not interested by 322 sd = checkpoint["state_dict"] 323 names = list(sd.keys()) 324 for name in names: 325 if any(x in name for x in ["_orig_mod", "fc", "teacher"]): 326 sd.pop(name) 327 328 329if __name__ == "__main__": 330 _ = CVModule(None, None, None, None, None, None)
13class CVModule(LightningModule): 14 """`LightningModule` for PAMELA. 15 16 A `LightningModule` implements 8 key methods: 17 18 ```python 19 def __init__(self): 20 # Define initialization code here. 21 22 def setup(self, stage): 23 # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'. 24 # This hook is called on every process when using DDP. 25 26 def training_step(self, batch, batch_idx): 27 # The complete training step. 28 29 def validation_step(self, batch, batch_idx): 30 # The complete validation step. 31 32 def test_step(self, batch, batch_idx): 33 # The complete test step. 34 35 def predict_step(self, batch, batch_idx): 36 # The complete predict step. 37 38 def configure_optimizers(self): 39 # Define and configure optimizers and LR schedulers. 40 ``` 41 42 Docs: 43 https://lightning.ai/docs/pytorch/latest/common/lightning_module.html 44 """ 45 46 def __init__( 47 self, 48 net: torch.nn.Module, 49 loss: Callable[[Tensor], Tensor], 50 optimizer: Callable[ 51 [Iterator[torch.nn.Parameter]], torch.optim.Optimizer 52 ], 53 scheduler: dict[str, Any] | None, 54 teacher: Callable[[Tensor], Tensor] | None = None, 55 terminate_on_nan: bool = True, 56 compile: bool = False, 57 **kwargs, 58 ) -> None: 59 """Initialize a `PAMELALitModule`. 60 61 :param net: The model to train. 62 :param optimizer: The optimizer to use for training. 63 :param scheduler: The learning rate scheduler to use for training, 64 together with the options (how often it should update etc.). 65 :param val_measures: The explainability specific validation measures to 66 use. 67 :param test_measures: " test ". 68 :param compile: Whether to compile the model before training. 
69 """ 70 super().__init__() 71 72 # this line allows to access init params with 'self.hparams' attribute 73 # also ensures init params will be stored in ckpt 74 self.save_hyperparameters( 75 ignore=["net", "loss", "optimizer", "scheduler", "teacher"], 76 logger=False, 77 ) 78 79 self.net = net 80 self.teacher = None 81 if teacher is not None: 82 self.teacher = teacher 83 loss = self.kl_loss 84 self.teacher.eval() 85 self.teacher.requires_grad_(False) 86 87 # cut off initialization for simple restore 88 if loss is None: 89 return 90 91 # loss function 92 self.criterion = loss 93 94 # optimizer and scheduler 95 self.optimizer_cfg = optimizer 96 self.scheduler_cfg = scheduler 97 98 # metric objects are initialized in `setup` method 99 self.train_acc = None 100 self.val_acc = None 101 self.test_acc = None 102 # for averaging loss across batches 103 self.train_loss = MeanMetric() 104 self.val_loss = MeanMetric() 105 self.test_loss = MeanMetric() 106 # metric object for calculating and averaging ADIC or ROAD or across batches 107 108 # for tracking best so far validation accuracy 109 self.val_acc_best = MaxMetric() 110 111 @staticmethod 112 def mse_loss(student_logits, teacher_logits): 113 return torch.nn.functional.mse_loss(student_logits, teacher_logits) 114 115 @staticmethod 116 def kl_loss(student_logits, teacher_logits, temperature=5.0): 117 p_t = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1) 118 p_s = torch.nn.functional.softmax(student_logits / temperature, dim=-1) 119 return torch.nn.functional.kl_div( 120 p_s.log(), p_t, reduction="batchmean" 121 ) 122 123 def convert_fc(self) -> None: 124 """Convert the last layer of the model to a fully connected layer.""" 125 if hasattr(self.net, "fc"): 126 in_features = self.net.fc.in_features 127 out_features = ( 128 self.trainer.datamodule.num_classes 129 if self.teacher is None 130 else self.teacher.fc.out_features 131 ) 132 self.classes = out_features 133 if self.net.fc.out_features == out_features: 
134 return 135 self.net.fc = torch.nn.Linear(in_features, out_features) 136 # initialize weights 137 torch.nn.init.xavier_uniform_(self.net.fc.weight) 138 if self.net.fc.bias is not None: 139 torch.nn.init.zeros_(self.net.fc.bias) 140 141 def forward(self, x: torch.Tensor) -> torch.Tensor: 142 """Perform a forward pass through the model `self.net`. 143 144 :param x: A tensor of images. 145 :return: A tensor of logits. 146 """ 147 if self.hparams.compile: 148 return self.compiled_net(x) 149 return self.net(x) 150 151 def on_train_start(self) -> None: 152 """Lightning hook that is called when training begins.""" 153 # by default lightning executes validation step sanity checks before training starts, 154 # so it's worth to make sure validation metrics don't store results from these checks 155 self.val_loss.reset() 156 self.val_acc.reset() 157 self.val_acc_best.reset() 158 159 def model_step( 160 self, batch: tuple[torch.Tensor, torch.Tensor] 161 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 162 """Perform a single model step on a batch of data. 163 164 :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. 165 166 :return: A tuple containing (in order): 167 - A tensor of losses. 168 - A tensor of predictions. 169 - A tensor of target labels. 170 """ 171 x, y = batch 172 target = y 173 if self.teacher is not None: 174 if self.hparams.compile: 175 y = self.compiled_teacher(x) 176 else: 177 y = self.teacher(x) 178 target = y.argmax(dim=1) 179 out = self.forward(x) 180 loss = self.criterion(out, y) 181 return loss, out, target 182 183 def training_step( 184 self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int 185 ) -> torch.Tensor: 186 """Perform a single training step on a batch of data from the training 187 set. 188 189 :param batch: A batch of data (a tuple) containing the input tensor of 190 images and target labels. 191 :param batch_idx: The index of the current batch. 
192 :return: A tensor of losses between model predictions and targets. 193 """ 194 loss, out, target = self.model_step(batch) 195 preds = out.argmax(dim=1) 196 # update and log metrics 197 self.train_loss(loss) # compute metric 198 self.log("train/loss", self.train_loss, prog_bar=True) 199 self.train_acc(preds, target) 200 self.log("train/acc", self.train_acc, prog_bar=True) 201 202 # return loss or backpropagation will fail 203 if self.hparams.terminate_on_nan and loss.isnan().any(): 204 raise ValueError("NaN detected in loss!") 205 return loss 206 207 def validation_step( 208 self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int 209 ) -> None: 210 """Perform a single validation step on a batch of data from the 211 validation set. 212 213 :param batch: A batch of data (a tuple) containing the input tensor of 214 images and target labels. 215 :param batch_idx: The index of the current batch. 216 """ 217 loss, out, target = self.model_step(batch) 218 preds = out.argmax(dim=1) 219 220 # update and log metrics 221 self.val_loss(loss) 222 self.log("val/loss", self.val_loss) 223 self.val_acc(preds, target) 224 self.log("val/acc", self.val_acc) 225 226 def on_validation_epoch_end(self) -> None: 227 "Lightning hook that is called when a validation epoch ends." 228 acc = self.val_acc.compute() # get current val acc 229 self.val_acc_best(acc) # update best so far val acc 230 # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object 231 # otherwise metric would be reset by lightning after each epoch 232 self.log( 233 "val/acc_best", 234 self.val_acc_best.compute(), 235 sync_dist=True, 236 prog_bar=True, 237 ) 238 239 def test_step( 240 self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int 241 ) -> None: 242 """Perform a single test step on a batch of data from the test set. 243 244 :param batch: A batch of data (a tuple) containing the input tensor of 245 images and target labels. 
246 :param batch_idx: The index of the current batch. 247 """ 248 loss, out, target = self.model_step(batch) 249 preds = out.argmax(dim=1) 250 251 # update and log metrics 252 self.test_loss(loss) 253 self.log("test/loss", self.test_loss) 254 self.test_acc(preds, target) 255 self.log("test/acc", self.test_acc) 256 257 def setup(self, stage: str) -> None: 258 """Lightning hook that is called at the beginning of fit (train + 259 validate), validate, test, or predict. 260 261 This is a good hook when you need to build models dynamically or adjust 262 something about them. This hook is called on every process when using 263 DDP. 264 265 :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 266 """ 267 self.convert_fc() 268 if self.hparams.compile: 269 self.compiled_net = torch.compile(self.net) 270 if self.teacher is not None: 271 self.compiled_teacher = torch.compile(self.teacher) 272 # metric objects for calculating and averaging accuracy across batches 273 self.train_acc = Accuracy(task="multiclass", num_classes=self.classes) 274 self.val_acc = Accuracy(task="multiclass", num_classes=self.classes) 275 self.test_acc = Accuracy(task="multiclass", num_classes=self.classes) 276 277 def teardown(self, stage: str) -> None: 278 """Lightning hook that is called at the end of fit (train + validate), 279 validate, test, or predict. 280 281 This is a good hook when you need to clean something up after the run. 282 283 :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 284 """ 285 if self.hparams.compile: 286 del self.compiled_net 287 if self.teacher: 288 del self.compiled_teacher 289 290 def configure_optimizers(self) -> dict[str, Any]: 291 """Choose what optimizers and learning-rate schedulers to use in your 292 optimization. Normally you'd need one. But in the case of GANs or 293 similar you might have multiple. 
294 295 Examples: 296 https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers 297 298 :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. 299 """ 300 optimizer = self.optimizer_cfg(params=self.trainer.model.parameters()) 301 if (scheduler_info := self.scheduler_cfg) is not None: 302 scheduler = scheduler_info["scheduler"] 303 if scheduler_info["name"] == "OneCycleLR": 304 scheduler = partial( 305 scheduler, 306 total_steps=self.trainer.estimated_stepping_batches, 307 ) 308 scheduler = scheduler(optimizer=optimizer) 309 scheduler_params = scheduler_info["params"] 310 return { 311 "optimizer": optimizer, 312 "lr_scheduler": { 313 "scheduler": scheduler, 314 "interval": scheduler_params["interval"], 315 "frequency": scheduler_params["frequency"], 316 "name": scheduler_params["name"], 317 }, 318 } 319 return {"optimizer": optimizer} 320 321 def on_save_checkpoint(self, checkpoint): 322 # pop the keys you are not interested by 323 sd = checkpoint["state_dict"] 324 names = list(sd.keys()) 325 for name in names: 326 if any(x in name for x in ["_orig_mod", "fc", "teacher"]): 327 sd.pop(name)
LightningModule for PAMELA.
A LightningModule implements the following key methods:
def __init__(self):
# Define initialization code here.
def setup(self, stage):
# Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
# This hook is called on every process when using DDP.
def training_step(self, batch, batch_idx):
# The complete training step.
def validation_step(self, batch, batch_idx):
# The complete validation step.
def test_step(self, batch, batch_idx):
# The complete test step.
def predict_step(self, batch, batch_idx):
# The complete predict step.
def configure_optimizers(self):
# Define and configure optimizers and LR schedulers.
Docs: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
def __init__(
    self,
    net: torch.nn.Module,
    loss: Callable[[Tensor, Tensor], Tensor] | None,
    optimizer: Callable[
        [Iterator[torch.nn.Parameter]], torch.optim.Optimizer
    ],
    scheduler: dict[str, Any] | None,
    teacher: torch.nn.Module | None = None,
    terminate_on_nan: bool = True,
    compile: bool = False,
    **kwargs,
) -> None:
    """Initialize a `CVModule`.

    :param net: The model to train.
    :param loss: The criterion, called as `loss(logits, target)`. It is
        replaced by `kl_loss` whenever a `teacher` is given. If it is
        `None` and there is no teacher, initialization stops right after
        `net` and `teacher` are stored (no criterion, optimizer config or
        metrics are created).
    :param optimizer: A factory that is later called with `params=...` to
        build the optimizer.
    :param scheduler: The learning rate scheduler configuration, or `None`
        for no scheduler. The keys `"scheduler"`, `"name"` and `"params"`
        are read when the optimizers are configured.
    :param teacher: Optional teacher network for distillation. It is put in
        eval mode and its parameters are frozen.
    :param terminate_on_nan: Stored as a hyperparameter; the training step
        raises when the loss is NaN and this flag is set.
    :param compile: Stored as a hyperparameter; selects the compiled
        networks built in `setup`.
    :param kwargs: Not read by this method.
    """
    super().__init__()

    # this line allows to access init params with 'self.hparams' attribute
    # also ensures init params will be stored in ckpt
    self.save_hyperparameters(
        ignore=["net", "loss", "optimizer", "scheduler", "teacher"],
        logger=False,
    )

    self.net = net
    self.teacher = None
    if teacher is not None:
        self.teacher = teacher
        # distillation always uses the KL criterion, whatever `loss` was
        loss = self.kl_loss
        # the teacher is a fixed reference: eval mode, no gradients
        self.teacher.eval()
        self.teacher.requires_grad_(False)

    # cut off initialization for simple restore
    if loss is None:
        return

    # loss function
    self.criterion = loss

    # optimizer and scheduler
    self.optimizer_cfg = optimizer
    self.scheduler_cfg = scheduler

    # metric objects are initialized in `setup` method
    self.train_acc = None
    self.val_acc = None
    self.test_acc = None
    # for averaging loss across batches
    self.train_loss = MeanMetric()
    self.val_loss = MeanMetric()
    self.test_loss = MeanMetric()

    # for tracking best so far validation accuracy
    self.val_acc_best = MaxMetric()
Initialize a CVModule.
Parameters
- net: The model to train.
- optimizer: The optimizer to use for training.
- scheduler: The learning rate scheduler to use for training, together with the options (how often it should update etc.).
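- loss: The criterion, called with the model logits and the target. It is replaced by the KL-divergence loss whenever a teacher is given.
- teacher: Optional teacher network for distillation; it is put in eval mode and its parameters are frozen.
- terminate_on_nan: Whether the training step raises an error when the loss is NaN.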
- compile: Whether to compile the model before training.
115 @staticmethod 116 def kl_loss(student_logits, teacher_logits, temperature=5.0): 117 p_t = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1) 118 p_s = torch.nn.functional.softmax(student_logits / temperature, dim=-1) 119 return torch.nn.functional.kl_div( 120 p_s.log(), p_t, reduction="batchmean" 121 )
def convert_fc(self) -> None:
    """Resize the `fc` head of `self.net` to the required number of classes.

    The class count is the teacher's `fc.out_features` when a teacher with
    an `fc` layer is present, otherwise `self.trainer.datamodule.num_classes`.
    It is always stored in `self.classes`, even when the network has no
    `fc` attribute, so that later consumers of `self.classes` do not fail.

    When `self.net.fc` exists and its output size differs, it is replaced
    by a new `torch.nn.Linear` with Xavier-uniform weights and zero bias.
    """
    teacher_head = getattr(self.teacher, "fc", None)
    if teacher_head is not None:
        out_features = teacher_head.out_features
    else:
        # no teacher (or a teacher without an `fc` head): ask the datamodule
        out_features = self.trainer.datamodule.num_classes
    self.classes = out_features

    head = getattr(self.net, "fc", None)
    if head is None or head.out_features == out_features:
        # nothing to replace, or the head already has the right size
        return

    new_head = torch.nn.Linear(head.in_features, out_features)
    # initialize weights
    torch.nn.init.xavier_uniform_(new_head.weight)
    if new_head.bias is not None:
        torch.nn.init.zeros_(new_head.bias)
    self.net.fc = new_head
Replace the `fc` head of the model with a freshly initialized linear layer when its output size does not match the required number of classes.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Perform a forward pass through the model `self.net`.

    The compiled network is used when `hparams.compile` is set and a
    `compiled_net` attribute exists. Otherwise the eager `self.net` is
    called, so the module stays usable when no compiled network is
    attached.

    :param x: A tensor of images.
    :return: A tensor of logits.
    """
    if self.hparams.compile:
        compiled = getattr(self, "compiled_net", None)
        if compiled is not None:
            return compiled(x)
    return self.net(x)
Perform a forward pass through the model self.net.
Parameters
- x: A tensor of images.
Returns
A tensor of logits.
def on_train_start(self) -> None:
    """Reset the validation metrics when training begins.

    Lightning runs validation sanity checks before training by default;
    whatever those checks accumulated must not leak into the real
    validation results.
    """
    for metric in (self.val_loss, self.val_acc, self.val_acc_best):
        metric.reset()
Lightning hook that is called when training begins.
def model_step(
    self, batch: tuple[torch.Tensor, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Perform a single model step on a batch of data.

    Without a teacher the criterion compares the student logits with the
    labels from the batch. With a teacher the batch labels are ignored:
    the criterion receives the teacher logits and the returned targets are
    the teacher's argmax over dimension 1.

    :param batch: A batch of data (a tuple) containing the input tensor of
        images and target labels.

    :return: A tuple containing (in order):
        - The loss.
        - The student logits (not yet argmax-ed).
        - The target labels.
    """
    x, y = batch
    target = y
    if self.teacher is not None:
        # A `.train()` call on the parent module also flips the teacher
        # back to train mode; re-assert eval so dropout stays off and
        # normalization statistics are not updated.
        if self.teacher.training:
            self.teacher.eval()
        teacher = (
            self.compiled_teacher if self.hparams.compile else self.teacher
        )
        # the teacher is a fixed reference: never track gradients for it
        with torch.no_grad():
            y = teacher(x)
        target = y.argmax(dim=1)
    out = self.forward(x)
    loss = self.criterion(out, y)
    return loss, out, target
Perform a single model step on a batch of data.
Parameters
- batch: A batch of data (a tuple) containing the input tensor of images and target labels.
Returns
A tuple containing (in order): - The loss. - A tensor of output logits. - A tensor of target labels.
def training_step(
    self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> torch.Tensor:
    """Perform a single training step on a batch of data from the training
    set.

    :param batch: A batch of data (a tuple) containing the input tensor of
        images and target labels.
    :param batch_idx: The index of the current batch.
    :return: A tensor of losses between model predictions and targets.
    :raises ValueError: If the loss contains NaN and
        `hparams.terminate_on_nan` is set.
    """
    loss, out, target = self.model_step(batch)

    # Fail before touching the metrics so a NaN loss is never folded into
    # the running averages or logged.
    if self.hparams.terminate_on_nan and loss.isnan().any():
        raise ValueError("NaN detected in loss!")

    preds = out.argmax(dim=1)
    # update and log metrics
    self.train_loss(loss)
    self.log("train/loss", self.train_loss, prog_bar=True)
    self.train_acc(preds, target)
    self.log("train/acc", self.train_acc, prog_bar=True)

    # the loss must be returned for backpropagation to run
    return loss
Perform a single training step on a batch of data from the training set.
Parameters
- batch: A batch of data (a tuple) containing the input tensor of images and target labels.
- batch_idx: The index of the current batch.
Returns
A tensor of losses between model predictions and targets.
def validation_step(
    self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> None:
    """Run one validation batch and update the validation metrics.

    :param batch: A batch of data (a tuple) containing the input tensor of
        images and target labels.
    :param batch_idx: The index of the current batch.
    """
    step_loss, logits, labels = self.model_step(batch)
    predicted = torch.argmax(logits, dim=1)

    # running loss first, then accuracy; each metric object is logged
    # right after it has been updated
    self.val_loss(step_loss)
    self.log("val/loss", self.val_loss)

    self.val_acc(predicted, labels)
    self.log("val/acc", self.val_acc)
Perform a single validation step on a batch of data from the validation set.
Parameters
- batch: A batch of data (a tuple) containing the input tensor of images and target labels.
- batch_idx: The index of the current batch.
def on_validation_epoch_end(self) -> None:
    """Update and log the best validation accuracy seen so far."""
    epoch_acc = self.val_acc.compute()
    self.val_acc_best(epoch_acc)
    # Log the computed value rather than the metric object, otherwise
    # Lightning would reset the running maximum after each epoch.
    best_so_far = self.val_acc_best.compute()
    self.log("val/acc_best", best_so_far, sync_dist=True, prog_bar=True)
Lightning hook that is called when a validation epoch ends.
239 def test_step( 240 self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int 241 ) -> None: 242 """Perform a single test step on a batch of data from the test set. 243 244 :param batch: A batch of data (a tuple) containing the input tensor of 245 images and target labels. 246 :param batch_idx: The index of the current batch. 247 """ 248 loss, out, target = self.model_step(batch) 249 preds = out.argmax(dim=1) 250 251 # update and log metrics 252 self.test_loss(loss) 253 self.log("test/loss", self.test_loss) 254 self.test_acc(preds, target) 255 self.log("test/acc", self.test_acc)
Perform a single test step on a batch of data from the test set.
Parameters
- batch: A batch of data (a tuple) containing the input tensor of images and target labels.
- batch_idx: The index of the current batch.
def setup(self, stage: str) -> None:
    """Prepare the head, the compiled networks and the accuracy metrics.

    Lightning calls this hook at the beginning of fit (train + validate),
    validate, test, or predict, on every process when using DDP.

    :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
    """
    self.convert_fc()

    use_compiled = self.hparams.compile
    if use_compiled:
        self.compiled_net = torch.compile(self.net)
    if use_compiled and self.teacher is not None:
        self.compiled_teacher = torch.compile(self.teacher)

    def new_accuracy():
        # one independent multiclass accuracy metric per split
        return Accuracy(task="multiclass", num_classes=self.classes)

    self.train_acc = new_accuracy()
    self.val_acc = new_accuracy()
    self.test_acc = new_accuracy()
Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.
This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
Parameters
- stage: Either "fit", "validate", "test", or "predict".
277 def teardown(self, stage: str) -> None: 278 """Lightning hook that is called at the end of fit (train + validate), 279 validate, test, or predict. 280 281 This is a good hook when you need to clean something up after the run. 282 283 :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 284 """ 285 if self.hparams.compile: 286 del self.compiled_net 287 if self.teacher: 288 del self.compiled_teacher
Lightning hook that is called at the end of fit (train + validate), validate, test, or predict.
This is a good hook when you need to clean something up after the run.
Parameters
- stage: Either "fit", "validate", "test", or "predict".
def configure_optimizers(self) -> dict[str, Any]:
    """Build the optimizer and, if configured, the learning-rate scheduler.

    The optimizer factory is called with the parameters of
    `self.trainer.model`. When a scheduler configuration is present, its
    `"scheduler"` factory is called with the optimizer; for
    `"OneCycleLR"` the factory is first bound to
    `total_steps=self.trainer.estimated_stepping_batches`. The
    `"interval"`, `"frequency"` and `"name"` entries of `"params"` are
    required; `"monitor"` and `"strict"` are forwarded when present.

    Examples:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

    :return: A dict containing the configured optimizers and learning-rate
        schedulers to be used for training.
    """
    optimizer = self.optimizer_cfg(params=self.trainer.model.parameters())
    scheduler_info = self.scheduler_cfg
    if scheduler_info is None:
        return {"optimizer": optimizer}

    scheduler_factory = scheduler_info["scheduler"]
    if scheduler_info["name"] == "OneCycleLR":
        # OneCycleLR needs the total step count, known only to the trainer
        scheduler_factory = partial(
            scheduler_factory,
            total_steps=self.trainer.estimated_stepping_batches,
        )
    scheduler = scheduler_factory(optimizer=optimizer)

    scheduler_params = scheduler_info["params"]
    lr_scheduler = {
        "scheduler": scheduler,
        "interval": scheduler_params["interval"],
        "frequency": scheduler_params["frequency"],
        "name": scheduler_params["name"],
    }
    # optional keys, e.g. the metric a plateau scheduler should watch
    for key in ("monitor", "strict"):
        if key in scheduler_params:
            lr_scheduler[key] = scheduler_params[key]
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple.
Examples: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
Returns
A dict containing the configured optimizers and learning-rate schedulers to be used for training.
def on_save_checkpoint(self, checkpoint):
    """Strip unwanted entries from the checkpoint's state dict in place.

    An entry is removed when one of the dot-separated components of its
    name is exactly `_orig_mod`, `fc` or `teacher`. Whole components are
    compared, not substrings, so parameters such as `mlp.fc1.weight` are
    kept.

    :param checkpoint: The checkpoint dictionary; its `"state_dict"` entry
        is modified in place.
    """
    dropped_components = {"_orig_mod", "fc", "teacher"}
    state_dict = checkpoint["state_dict"]
    # iterate over a snapshot of the keys: entries are deleted in the loop
    for name in list(state_dict):
        if dropped_components.intersection(name.split(".")):
            del state_dict[name]
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
Args: checkpoint: The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
Example::
def on_save_checkpoint(self, checkpoint):
# 99% of use cases you don't need to implement this method
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note: Lightning saves all aspects of training (epoch, global step, etc...) including amp scaling. There is no need for you to store anything about training.