ptame.models.cv_module

  1from collections.abc import Callable, Iterator
  2from functools import partial
  3from typing import Any
  4
  5import torch
  6from lightning import LightningModule
  7from torch import Tensor
  8from torchmetrics import MaxMetric, MeanMetric
  9from torchmetrics.classification.accuracy import Accuracy
 10
 11
 12class CVModule(LightningModule):
 13    """`LightningModule` for PAMELA.
 14
 15    A `LightningModule` implements 8 key methods:
 16
 17    ```python
 18    def __init__(self):
 19    # Define initialization code here.
 20
 21    def setup(self, stage):
 22    # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
 23    # This hook is called on every process when using DDP.
 24
 25    def training_step(self, batch, batch_idx):
 26    # The complete training step.
 27
 28    def validation_step(self, batch, batch_idx):
 29    # The complete validation step.
 30
 31    def test_step(self, batch, batch_idx):
 32    # The complete test step.
 33
 34    def predict_step(self, batch, batch_idx):
 35    # The complete predict step.
 36
 37    def configure_optimizers(self):
 38    # Define and configure optimizers and LR schedulers.
 39    ```
 40
 41    Docs:
 42        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
 43    """
 44
 45    def __init__(
 46        self,
 47        net: torch.nn.Module,
 48        loss: Callable[[Tensor], Tensor],
 49        optimizer: Callable[
 50            [Iterator[torch.nn.Parameter]], torch.optim.Optimizer
 51        ],
 52        scheduler: dict[str, Any] | None,
 53        teacher: Callable[[Tensor], Tensor] | None = None,
 54        terminate_on_nan: bool = True,
 55        compile: bool = False,
 56        **kwargs,
 57    ) -> None:
 58        """Initialize a `PAMELALitModule`.
 59
 60        :param net: The model to train.
 61        :param optimizer: The optimizer to use for training.
 62        :param scheduler: The learning rate scheduler to use for training,
 63            together with the options (how often it should update etc.).
 64        :param val_measures: The explainability specific validation measures to
 65            use.
 66        :param test_measures: " test ".
 67        :param compile: Whether to compile the model before training.
 68        """
 69        super().__init__()
 70
 71        # this line allows to access init params with 'self.hparams' attribute
 72        # also ensures init params will be stored in ckpt
 73        self.save_hyperparameters(
 74            ignore=["net", "loss", "optimizer", "scheduler", "teacher"],
 75            logger=False,
 76        )
 77
 78        self.net = net
 79        self.teacher = None
 80        if teacher is not None:
 81            self.teacher = teacher
 82            loss = self.kl_loss
 83            self.teacher.eval()
 84            self.teacher.requires_grad_(False)
 85
 86        # cut off initialization for simple restore
 87        if loss is None:
 88            return
 89
 90        # loss function
 91        self.criterion = loss
 92
 93        # optimizer and scheduler
 94        self.optimizer_cfg = optimizer
 95        self.scheduler_cfg = scheduler
 96
 97        # metric objects are initialized in `setup` method
 98        self.train_acc = None
 99        self.val_acc = None
100        self.test_acc = None
101        # for averaging loss across batches
102        self.train_loss = MeanMetric()
103        self.val_loss = MeanMetric()
104        self.test_loss = MeanMetric()
105        # metric object for calculating and averaging ADIC or ROAD or across batches
106
107        # for tracking best so far validation accuracy
108        self.val_acc_best = MaxMetric()
109
110    @staticmethod
111    def mse_loss(student_logits, teacher_logits):
112        return torch.nn.functional.mse_loss(student_logits, teacher_logits)
113
114    @staticmethod
115    def kl_loss(student_logits, teacher_logits, temperature=5.0):
116        p_t = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
117        p_s = torch.nn.functional.softmax(student_logits / temperature, dim=-1)
118        return torch.nn.functional.kl_div(
119            p_s.log(), p_t, reduction="batchmean"
120        )
121
122    def convert_fc(self) -> None:
123        """Convert the last layer of the model to a fully connected layer."""
124        if hasattr(self.net, "fc"):
125            in_features = self.net.fc.in_features
126            out_features = (
127                self.trainer.datamodule.num_classes
128                if self.teacher is None
129                else self.teacher.fc.out_features
130            )
131            self.classes = out_features
132            if self.net.fc.out_features == out_features:
133                return
134            self.net.fc = torch.nn.Linear(in_features, out_features)
135            # initialize weights
136            torch.nn.init.xavier_uniform_(self.net.fc.weight)
137            if self.net.fc.bias is not None:
138                torch.nn.init.zeros_(self.net.fc.bias)
139
140    def forward(self, x: torch.Tensor) -> torch.Tensor:
141        """Perform a forward pass through the model `self.net`.
142
143        :param x: A tensor of images.
144        :return: A tensor of logits.
145        """
146        if self.hparams.compile:
147            return self.compiled_net(x)
148        return self.net(x)
149
150    def on_train_start(self) -> None:
151        """Lightning hook that is called when training begins."""
152        # by default lightning executes validation step sanity checks before training starts,
153        # so it's worth to make sure validation metrics don't store results from these checks
154        self.val_loss.reset()
155        self.val_acc.reset()
156        self.val_acc_best.reset()
157
158    def model_step(
159        self, batch: tuple[torch.Tensor, torch.Tensor]
160    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
161        """Perform a single model step on a batch of data.
162
163        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
164
165        :return: A tuple containing (in order):
166            - A tensor of losses.
167            - A tensor of predictions.
168            - A tensor of target labels.
169        """
170        x, y = batch
171        target = y
172        if self.teacher is not None:
173            if self.hparams.compile:
174                y = self.compiled_teacher(x)
175            else:
176                y = self.teacher(x)
177            target = y.argmax(dim=1)
178        out = self.forward(x)
179        loss = self.criterion(out, y)
180        return loss, out, target
181
182    def training_step(
183        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
184    ) -> torch.Tensor:
185        """Perform a single training step on a batch of data from the training
186        set.
187
188        :param batch: A batch of data (a tuple) containing the input tensor of
189            images and target labels.
190        :param batch_idx: The index of the current batch.
191        :return: A tensor of losses between model predictions and targets.
192        """
193        loss, out, target = self.model_step(batch)
194        preds = out.argmax(dim=1)
195        # update and log metrics
196        self.train_loss(loss)  # compute metric
197        self.log("train/loss", self.train_loss, prog_bar=True)
198        self.train_acc(preds, target)
199        self.log("train/acc", self.train_acc, prog_bar=True)
200
201        # return loss or backpropagation will fail
202        if self.hparams.terminate_on_nan and loss.isnan().any():
203            raise ValueError("NaN detected in loss!")
204        return loss
205
206    def validation_step(
207        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
208    ) -> None:
209        """Perform a single validation step on a batch of data from the
210        validation set.
211
212        :param batch: A batch of data (a tuple) containing the input tensor of
213            images and target labels.
214        :param batch_idx: The index of the current batch.
215        """
216        loss, out, target = self.model_step(batch)
217        preds = out.argmax(dim=1)
218
219        # update and log metrics
220        self.val_loss(loss)
221        self.log("val/loss", self.val_loss)
222        self.val_acc(preds, target)
223        self.log("val/acc", self.val_acc)
224
225    def on_validation_epoch_end(self) -> None:
226        "Lightning hook that is called when a validation epoch ends."
227        acc = self.val_acc.compute()  # get current val acc
228        self.val_acc_best(acc)  # update best so far val acc
229        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
230        # otherwise metric would be reset by lightning after each epoch
231        self.log(
232            "val/acc_best",
233            self.val_acc_best.compute(),
234            sync_dist=True,
235            prog_bar=True,
236        )
237
238    def test_step(
239        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
240    ) -> None:
241        """Perform a single test step on a batch of data from the test set.
242
243        :param batch: A batch of data (a tuple) containing the input tensor of
244            images and target labels.
245        :param batch_idx: The index of the current batch.
246        """
247        loss, out, target = self.model_step(batch)
248        preds = out.argmax(dim=1)
249
250        # update and log metrics
251        self.test_loss(loss)
252        self.log("test/loss", self.test_loss)
253        self.test_acc(preds, target)
254        self.log("test/acc", self.test_acc)
255
256    def setup(self, stage: str) -> None:
257        """Lightning hook that is called at the beginning of fit (train +
258        validate), validate, test, or predict.
259
260        This is a good hook when you need to build models dynamically or adjust
261        something about them. This hook is called on every process when using
262        DDP.
263
264        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
265        """
266        self.convert_fc()
267        if self.hparams.compile:
268            self.compiled_net = torch.compile(self.net)
269            if self.teacher is not None:
270                self.compiled_teacher = torch.compile(self.teacher)
271        # metric objects for calculating and averaging accuracy across batches
272        self.train_acc = Accuracy(task="multiclass", num_classes=self.classes)
273        self.val_acc = Accuracy(task="multiclass", num_classes=self.classes)
274        self.test_acc = Accuracy(task="multiclass", num_classes=self.classes)
275
276    def teardown(self, stage: str) -> None:
277        """Lightning hook that is called at the end of fit (train + validate),
278        validate, test, or predict.
279
280        This is a good hook when you need to clean something up after the run.
281
282        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
283        """
284        if self.hparams.compile:
285            del self.compiled_net
286            if self.teacher:
287                del self.compiled_teacher
288
289    def configure_optimizers(self) -> dict[str, Any]:
290        """Choose what optimizers and learning-rate schedulers to use in your
291        optimization. Normally you'd need one. But in the case of GANs or
292        similar you might have multiple.
293
294        Examples:
295            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
296
297        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
298        """
299        optimizer = self.optimizer_cfg(params=self.trainer.model.parameters())
300        if (scheduler_info := self.scheduler_cfg) is not None:
301            scheduler = scheduler_info["scheduler"]
302            if scheduler_info["name"] == "OneCycleLR":
303                scheduler = partial(
304                    scheduler,
305                    total_steps=self.trainer.estimated_stepping_batches,
306                )
307            scheduler = scheduler(optimizer=optimizer)
308            scheduler_params = scheduler_info["params"]
309            return {
310                "optimizer": optimizer,
311                "lr_scheduler": {
312                    "scheduler": scheduler,
313                    "interval": scheduler_params["interval"],
314                    "frequency": scheduler_params["frequency"],
315                    "name": scheduler_params["name"],
316                },
317            }
318        return {"optimizer": optimizer}
319
320    def on_save_checkpoint(self, checkpoint):
321        # pop the keys you are not interested by
322        sd = checkpoint["state_dict"]
323        names = list(sd.keys())
324        for name in names:
325            if any(x in name for x in ["_orig_mod", "fc", "teacher"]):
326                sd.pop(name)
327
328
if __name__ == "__main__":
    # smoke check: a bare module can be constructed without any components
    _ = CVModule(
        net=None,
        loss=None,
        optimizer=None,
        scheduler=None,
        teacher=None,
        terminate_on_nan=None,
    )
class CVModule(lightning.pytorch.core.module.LightningModule):
 13class CVModule(LightningModule):
 14    """`LightningModule` for PAMELA.
 15
 16    A `LightningModule` implements 8 key methods:
 17
 18    ```python
 19    def __init__(self):
 20    # Define initialization code here.
 21
 22    def setup(self, stage):
 23    # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
 24    # This hook is called on every process when using DDP.
 25
 26    def training_step(self, batch, batch_idx):
 27    # The complete training step.
 28
 29    def validation_step(self, batch, batch_idx):
 30    # The complete validation step.
 31
 32    def test_step(self, batch, batch_idx):
 33    # The complete test step.
 34
 35    def predict_step(self, batch, batch_idx):
 36    # The complete predict step.
 37
 38    def configure_optimizers(self):
 39    # Define and configure optimizers and LR schedulers.
 40    ```
 41
 42    Docs:
 43        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
 44    """
 45
 46    def __init__(
 47        self,
 48        net: torch.nn.Module,
 49        loss: Callable[[Tensor], Tensor],
 50        optimizer: Callable[
 51            [Iterator[torch.nn.Parameter]], torch.optim.Optimizer
 52        ],
 53        scheduler: dict[str, Any] | None,
 54        teacher: Callable[[Tensor], Tensor] | None = None,
 55        terminate_on_nan: bool = True,
 56        compile: bool = False,
 57        **kwargs,
 58    ) -> None:
 59        """Initialize a `PAMELALitModule`.
 60
 61        :param net: The model to train.
 62        :param optimizer: The optimizer to use for training.
 63        :param scheduler: The learning rate scheduler to use for training,
 64            together with the options (how often it should update etc.).
 65        :param val_measures: The explainability specific validation measures to
 66            use.
 67        :param test_measures: " test ".
 68        :param compile: Whether to compile the model before training.
 69        """
 70        super().__init__()
 71
 72        # this line allows to access init params with 'self.hparams' attribute
 73        # also ensures init params will be stored in ckpt
 74        self.save_hyperparameters(
 75            ignore=["net", "loss", "optimizer", "scheduler", "teacher"],
 76            logger=False,
 77        )
 78
 79        self.net = net
 80        self.teacher = None
 81        if teacher is not None:
 82            self.teacher = teacher
 83            loss = self.kl_loss
 84            self.teacher.eval()
 85            self.teacher.requires_grad_(False)
 86
 87        # cut off initialization for simple restore
 88        if loss is None:
 89            return
 90
 91        # loss function
 92        self.criterion = loss
 93
 94        # optimizer and scheduler
 95        self.optimizer_cfg = optimizer
 96        self.scheduler_cfg = scheduler
 97
 98        # metric objects are initialized in `setup` method
 99        self.train_acc = None
100        self.val_acc = None
101        self.test_acc = None
102        # for averaging loss across batches
103        self.train_loss = MeanMetric()
104        self.val_loss = MeanMetric()
105        self.test_loss = MeanMetric()
106        # metric object for calculating and averaging ADIC or ROAD or across batches
107
108        # for tracking best so far validation accuracy
109        self.val_acc_best = MaxMetric()
110
111    @staticmethod
112    def mse_loss(student_logits, teacher_logits):
113        return torch.nn.functional.mse_loss(student_logits, teacher_logits)
114
115    @staticmethod
116    def kl_loss(student_logits, teacher_logits, temperature=5.0):
117        p_t = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
118        p_s = torch.nn.functional.softmax(student_logits / temperature, dim=-1)
119        return torch.nn.functional.kl_div(
120            p_s.log(), p_t, reduction="batchmean"
121        )
122
123    def convert_fc(self) -> None:
124        """Convert the last layer of the model to a fully connected layer."""
125        if hasattr(self.net, "fc"):
126            in_features = self.net.fc.in_features
127            out_features = (
128                self.trainer.datamodule.num_classes
129                if self.teacher is None
130                else self.teacher.fc.out_features
131            )
132            self.classes = out_features
133            if self.net.fc.out_features == out_features:
134                return
135            self.net.fc = torch.nn.Linear(in_features, out_features)
136            # initialize weights
137            torch.nn.init.xavier_uniform_(self.net.fc.weight)
138            if self.net.fc.bias is not None:
139                torch.nn.init.zeros_(self.net.fc.bias)
140
141    def forward(self, x: torch.Tensor) -> torch.Tensor:
142        """Perform a forward pass through the model `self.net`.
143
144        :param x: A tensor of images.
145        :return: A tensor of logits.
146        """
147        if self.hparams.compile:
148            return self.compiled_net(x)
149        return self.net(x)
150
151    def on_train_start(self) -> None:
152        """Lightning hook that is called when training begins."""
153        # by default lightning executes validation step sanity checks before training starts,
154        # so it's worth to make sure validation metrics don't store results from these checks
155        self.val_loss.reset()
156        self.val_acc.reset()
157        self.val_acc_best.reset()
158
159    def model_step(
160        self, batch: tuple[torch.Tensor, torch.Tensor]
161    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
162        """Perform a single model step on a batch of data.
163
164        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
165
166        :return: A tuple containing (in order):
167            - A tensor of losses.
168            - A tensor of predictions.
169            - A tensor of target labels.
170        """
171        x, y = batch
172        target = y
173        if self.teacher is not None:
174            if self.hparams.compile:
175                y = self.compiled_teacher(x)
176            else:
177                y = self.teacher(x)
178            target = y.argmax(dim=1)
179        out = self.forward(x)
180        loss = self.criterion(out, y)
181        return loss, out, target
182
183    def training_step(
184        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
185    ) -> torch.Tensor:
186        """Perform a single training step on a batch of data from the training
187        set.
188
189        :param batch: A batch of data (a tuple) containing the input tensor of
190            images and target labels.
191        :param batch_idx: The index of the current batch.
192        :return: A tensor of losses between model predictions and targets.
193        """
194        loss, out, target = self.model_step(batch)
195        preds = out.argmax(dim=1)
196        # update and log metrics
197        self.train_loss(loss)  # compute metric
198        self.log("train/loss", self.train_loss, prog_bar=True)
199        self.train_acc(preds, target)
200        self.log("train/acc", self.train_acc, prog_bar=True)
201
202        # return loss or backpropagation will fail
203        if self.hparams.terminate_on_nan and loss.isnan().any():
204            raise ValueError("NaN detected in loss!")
205        return loss
206
207    def validation_step(
208        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
209    ) -> None:
210        """Perform a single validation step on a batch of data from the
211        validation set.
212
213        :param batch: A batch of data (a tuple) containing the input tensor of
214            images and target labels.
215        :param batch_idx: The index of the current batch.
216        """
217        loss, out, target = self.model_step(batch)
218        preds = out.argmax(dim=1)
219
220        # update and log metrics
221        self.val_loss(loss)
222        self.log("val/loss", self.val_loss)
223        self.val_acc(preds, target)
224        self.log("val/acc", self.val_acc)
225
226    def on_validation_epoch_end(self) -> None:
227        "Lightning hook that is called when a validation epoch ends."
228        acc = self.val_acc.compute()  # get current val acc
229        self.val_acc_best(acc)  # update best so far val acc
230        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
231        # otherwise metric would be reset by lightning after each epoch
232        self.log(
233            "val/acc_best",
234            self.val_acc_best.compute(),
235            sync_dist=True,
236            prog_bar=True,
237        )
238
239    def test_step(
240        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
241    ) -> None:
242        """Perform a single test step on a batch of data from the test set.
243
244        :param batch: A batch of data (a tuple) containing the input tensor of
245            images and target labels.
246        :param batch_idx: The index of the current batch.
247        """
248        loss, out, target = self.model_step(batch)
249        preds = out.argmax(dim=1)
250
251        # update and log metrics
252        self.test_loss(loss)
253        self.log("test/loss", self.test_loss)
254        self.test_acc(preds, target)
255        self.log("test/acc", self.test_acc)
256
257    def setup(self, stage: str) -> None:
258        """Lightning hook that is called at the beginning of fit (train +
259        validate), validate, test, or predict.
260
261        This is a good hook when you need to build models dynamically or adjust
262        something about them. This hook is called on every process when using
263        DDP.
264
265        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
266        """
267        self.convert_fc()
268        if self.hparams.compile:
269            self.compiled_net = torch.compile(self.net)
270            if self.teacher is not None:
271                self.compiled_teacher = torch.compile(self.teacher)
272        # metric objects for calculating and averaging accuracy across batches
273        self.train_acc = Accuracy(task="multiclass", num_classes=self.classes)
274        self.val_acc = Accuracy(task="multiclass", num_classes=self.classes)
275        self.test_acc = Accuracy(task="multiclass", num_classes=self.classes)
276
277    def teardown(self, stage: str) -> None:
278        """Lightning hook that is called at the end of fit (train + validate),
279        validate, test, or predict.
280
281        This is a good hook when you need to clean something up after the run.
282
283        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
284        """
285        if self.hparams.compile:
286            del self.compiled_net
287            if self.teacher:
288                del self.compiled_teacher
289
290    def configure_optimizers(self) -> dict[str, Any]:
291        """Choose what optimizers and learning-rate schedulers to use in your
292        optimization. Normally you'd need one. But in the case of GANs or
293        similar you might have multiple.
294
295        Examples:
296            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
297
298        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
299        """
300        optimizer = self.optimizer_cfg(params=self.trainer.model.parameters())
301        if (scheduler_info := self.scheduler_cfg) is not None:
302            scheduler = scheduler_info["scheduler"]
303            if scheduler_info["name"] == "OneCycleLR":
304                scheduler = partial(
305                    scheduler,
306                    total_steps=self.trainer.estimated_stepping_batches,
307                )
308            scheduler = scheduler(optimizer=optimizer)
309            scheduler_params = scheduler_info["params"]
310            return {
311                "optimizer": optimizer,
312                "lr_scheduler": {
313                    "scheduler": scheduler,
314                    "interval": scheduler_params["interval"],
315                    "frequency": scheduler_params["frequency"],
316                    "name": scheduler_params["name"],
317                },
318            }
319        return {"optimizer": optimizer}
320
321    def on_save_checkpoint(self, checkpoint):
322        # pop the keys you are not interested by
323        sd = checkpoint["state_dict"]
324        names = list(sd.keys())
325        for name in names:
326            if any(x in name for x in ["_orig_mod", "fc", "teacher"]):
327                sd.pop(name)

LightningModule for PAMELA.

A LightningModule implements 7 key methods:

def __init__(self):
# Define initialization code here.

def setup(self, stage):
# Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
# This hook is called on every process when using DDP.

def training_step(self, batch, batch_idx):
# The complete training step.

def validation_step(self, batch, batch_idx):
# The complete validation step.

def test_step(self, batch, batch_idx):
# The complete test step.

def predict_step(self, batch, batch_idx):
# The complete predict step.

def configure_optimizers(self):
# Define and configure optimizers and LR schedulers.

Docs: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html

CVModule( net: torch.nn.modules.module.Module, loss: Callable[[torch.Tensor], torch.Tensor], optimizer: Callable[[Iterator[torch.nn.parameter.Parameter]], torch.optim.optimizer.Optimizer], scheduler: dict[str, typing.Any] | None, teacher: Callable[[torch.Tensor], torch.Tensor] | None = None, terminate_on_nan: bool = True, compile: bool = False, **kwargs)
 46    def __init__(
 47        self,
 48        net: torch.nn.Module,
 49        loss: Callable[[Tensor], Tensor],
 50        optimizer: Callable[
 51            [Iterator[torch.nn.Parameter]], torch.optim.Optimizer
 52        ],
 53        scheduler: dict[str, Any] | None,
 54        teacher: Callable[[Tensor], Tensor] | None = None,
 55        terminate_on_nan: bool = True,
 56        compile: bool = False,
 57        **kwargs,
 58    ) -> None:
 59        """Initialize a `PAMELALitModule`.
 60
 61        :param net: The model to train.
 62        :param optimizer: The optimizer to use for training.
 63        :param scheduler: The learning rate scheduler to use for training,
 64            together with the options (how often it should update etc.).
 65        :param val_measures: The explainability specific validation measures to
 66            use.
 67        :param test_measures: " test ".
 68        :param compile: Whether to compile the model before training.
 69        """
 70        super().__init__()
 71
 72        # this line allows to access init params with 'self.hparams' attribute
 73        # also ensures init params will be stored in ckpt
 74        self.save_hyperparameters(
 75            ignore=["net", "loss", "optimizer", "scheduler", "teacher"],
 76            logger=False,
 77        )
 78
 79        self.net = net
 80        self.teacher = None
 81        if teacher is not None:
 82            self.teacher = teacher
 83            loss = self.kl_loss
 84            self.teacher.eval()
 85            self.teacher.requires_grad_(False)
 86
 87        # cut off initialization for simple restore
 88        if loss is None:
 89            return
 90
 91        # loss function
 92        self.criterion = loss
 93
 94        # optimizer and scheduler
 95        self.optimizer_cfg = optimizer
 96        self.scheduler_cfg = scheduler
 97
 98        # metric objects are initialized in `setup` method
 99        self.train_acc = None
100        self.val_acc = None
101        self.test_acc = None
102        # for averaging loss across batches
103        self.train_loss = MeanMetric()
104        self.val_loss = MeanMetric()
105        self.test_loss = MeanMetric()
106        # metric object for calculating and averaging ADIC or ROAD or across batches
107
108        # for tracking best so far validation accuracy
109        self.val_acc_best = MaxMetric()

Initialize a CVModule.

Parameters
  • net: The model to train.
  • optimizer: The optimizer to use for training.
  • scheduler: The learning rate scheduler to use for training, together with the options (how often it should update etc.).
  • val_measures: The explainability specific validation measures to use.
  • test_measures: " test ".
  • compile: Whether to compile the model before training.
net
teacher
criterion
optimizer_cfg
scheduler_cfg
train_acc
val_acc
test_acc
train_loss
val_loss
test_loss
val_acc_best
@staticmethod
def mse_loss(student_logits, teacher_logits):
111    @staticmethod
112    def mse_loss(student_logits, teacher_logits):
113        return torch.nn.functional.mse_loss(student_logits, teacher_logits)
@staticmethod
def kl_loss(student_logits, teacher_logits, temperature=5.0):
115    @staticmethod
116    def kl_loss(student_logits, teacher_logits, temperature=5.0):
117        p_t = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
118        p_s = torch.nn.functional.softmax(student_logits / temperature, dim=-1)
119        return torch.nn.functional.kl_div(
120            p_s.log(), p_t, reduction="batchmean"
121        )
def convert_fc(self) -> None:
123    def convert_fc(self) -> None:
124        """Convert the last layer of the model to a fully connected layer."""
125        if hasattr(self.net, "fc"):
126            in_features = self.net.fc.in_features
127            out_features = (
128                self.trainer.datamodule.num_classes
129                if self.teacher is None
130                else self.teacher.fc.out_features
131            )
132            self.classes = out_features
133            if self.net.fc.out_features == out_features:
134                return
135            self.net.fc = torch.nn.Linear(in_features, out_features)
136            # initialize weights
137            torch.nn.init.xavier_uniform_(self.net.fc.weight)
138            if self.net.fc.bias is not None:
139                torch.nn.init.zeros_(self.net.fc.bias)

Resize the model's final fully connected layer (`fc`) so that its output size matches the number of target classes.

def forward(self, x: torch.Tensor) -> torch.Tensor:
141    def forward(self, x: torch.Tensor) -> torch.Tensor:
142        """Perform a forward pass through the model `self.net`.
143
144        :param x: A tensor of images.
145        :return: A tensor of logits.
146        """
147        if self.hparams.compile:
148            return self.compiled_net(x)
149        return self.net(x)

Perform a forward pass through the model self.net.

Parameters
  • x: A tensor of images.
Returns

A tensor of logits.

def on_train_start(self) -> None:
151    def on_train_start(self) -> None:
152        """Lightning hook that is called when training begins."""
153        # by default lightning executes validation step sanity checks before training starts,
154        # so it's worth to make sure validation metrics don't store results from these checks
155        self.val_loss.reset()
156        self.val_acc.reset()
157        self.val_acc_best.reset()

Lightning hook that is called when training begins.

def model_step( self, batch: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
159    def model_step(
160        self, batch: tuple[torch.Tensor, torch.Tensor]
161    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
162        """Perform a single model step on a batch of data.
163
164        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
165
166        :return: A tuple containing (in order):
167            - A tensor of losses.
168            - A tensor of predictions.
169            - A tensor of target labels.
170        """
171        x, y = batch
172        target = y
173        if self.teacher is not None:
174            if self.hparams.compile:
175                y = self.compiled_teacher(x)
176            else:
177                y = self.teacher(x)
178            target = y.argmax(dim=1)
179        out = self.forward(x)
180        loss = self.criterion(out, y)
181        return loss, out, target

Perform a single model step on a batch of data.

Parameters
  • batch: A batch of data (a tuple) containing the input tensor of images and target labels.
Returns

A tuple containing, in order: the loss tensor, the model output for the batch, and the target labels.

def training_step( self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
183    def training_step(
184        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
185    ) -> torch.Tensor:
186        """Perform a single training step on a batch of data from the training
187        set.
188
189        :param batch: A batch of data (a tuple) containing the input tensor of
190            images and target labels.
191        :param batch_idx: The index of the current batch.
192        :return: A tensor of losses between model predictions and targets.
193        """
194        loss, out, target = self.model_step(batch)
195        preds = out.argmax(dim=1)
196        # update and log metrics
197        self.train_loss(loss)  # compute metric
198        self.log("train/loss", self.train_loss, prog_bar=True)
199        self.train_acc(preds, target)
200        self.log("train/acc", self.train_acc, prog_bar=True)
201
202        # return loss or backpropagation will fail
203        if self.hparams.terminate_on_nan and loss.isnan().any():
204            raise ValueError("NaN detected in loss!")
205        return loss

Perform a single training step on a batch of data from the training set.

Parameters
  • batch: A batch of data (a tuple) containing the input tensor of images and target labels.
  • batch_idx: The index of the current batch.
Returns

A tensor of losses between model predictions and targets.

def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
207    def validation_step(
208        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
209    ) -> None:
210        """Perform a single validation step on a batch of data from the
211        validation set.
212
213        :param batch: A batch of data (a tuple) containing the input tensor of
214            images and target labels.
215        :param batch_idx: The index of the current batch.
216        """
217        loss, out, target = self.model_step(batch)
218        preds = out.argmax(dim=1)
219
220        # update and log metrics
221        self.val_loss(loss)
222        self.log("val/loss", self.val_loss)
223        self.val_acc(preds, target)
224        self.log("val/acc", self.val_acc)

Perform a single validation step on a batch of data from the validation set.

Parameters
  • batch: A batch of data (a tuple) containing the input tensor of images and target labels.
  • batch_idx: The index of the current batch.
def on_validation_epoch_end(self) -> None:
226    def on_validation_epoch_end(self) -> None:
227        "Lightning hook that is called when a validation epoch ends."
228        acc = self.val_acc.compute()  # get current val acc
229        self.val_acc_best(acc)  # update best so far val acc
230        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
231        # otherwise metric would be reset by lightning after each epoch
232        self.log(
233            "val/acc_best",
234            self.val_acc_best.compute(),
235            sync_dist=True,
236            prog_bar=True,
237        )

Lightning hook that is called when a validation epoch ends.

def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
239    def test_step(
240        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
241    ) -> None:
242        """Perform a single test step on a batch of data from the test set.
243
244        :param batch: A batch of data (a tuple) containing the input tensor of
245            images and target labels.
246        :param batch_idx: The index of the current batch.
247        """
248        loss, out, target = self.model_step(batch)
249        preds = out.argmax(dim=1)
250
251        # update and log metrics
252        self.test_loss(loss)
253        self.log("test/loss", self.test_loss)
254        self.test_acc(preds, target)
255        self.log("test/acc", self.test_acc)

Perform a single test step on a batch of data from the test set.

Parameters
  • batch: A batch of data (a tuple) containing the input tensor of images and target labels.
  • batch_idx: The index of the current batch.
def setup(self, stage: str) -> None:
257    def setup(self, stage: str) -> None:
258        """Lightning hook that is called at the beginning of fit (train +
259        validate), validate, test, or predict.
260
261        This is a good hook when you need to build models dynamically or adjust
262        something about them. This hook is called on every process when using
263        DDP.
264
265        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
266        """
267        self.convert_fc()
268        if self.hparams.compile:
269            self.compiled_net = torch.compile(self.net)
270            if self.teacher is not None:
271                self.compiled_teacher = torch.compile(self.teacher)
272        # metric objects for calculating and averaging accuracy across batches
273        self.train_acc = Accuracy(task="multiclass", num_classes=self.classes)
274        self.val_acc = Accuracy(task="multiclass", num_classes=self.classes)
275        self.test_acc = Accuracy(task="multiclass", num_classes=self.classes)

Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.

This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters
  • stage: Either "fit", "validate", "test", or "predict".
def teardown(self, stage: str) -> None:
277    def teardown(self, stage: str) -> None:
278        """Lightning hook that is called at the end of fit (train + validate),
279        validate, test, or predict.
280
281        This is a good hook when you need to clean something up after the run.
282
283        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
284        """
285        if self.hparams.compile:
286            del self.compiled_net
287            if self.teacher:
288                del self.compiled_teacher

Lightning hook that is called at the end of fit (train + validate), validate, test, or predict.

This is a good hook when you need to clean something up after the run.

Parameters
  • stage: Either "fit", "validate", "test", or "predict".
def configure_optimizers(self) -> dict[str, typing.Any]:
290    def configure_optimizers(self) -> dict[str, Any]:
291        """Choose what optimizers and learning-rate schedulers to use in your
292        optimization. Normally you'd need one. But in the case of GANs or
293        similar you might have multiple.
294
295        Examples:
296            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
297
298        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
299        """
300        optimizer = self.optimizer_cfg(params=self.trainer.model.parameters())
301        if (scheduler_info := self.scheduler_cfg) is not None:
302            scheduler = scheduler_info["scheduler"]
303            if scheduler_info["name"] == "OneCycleLR":
304                scheduler = partial(
305                    scheduler,
306                    total_steps=self.trainer.estimated_stepping_batches,
307                )
308            scheduler = scheduler(optimizer=optimizer)
309            scheduler_params = scheduler_info["params"]
310            return {
311                "optimizer": optimizer,
312                "lr_scheduler": {
313                    "scheduler": scheduler,
314                    "interval": scheduler_params["interval"],
315                    "frequency": scheduler_params["frequency"],
316                    "name": scheduler_params["name"],
317                },
318            }
319        return {"optimizer": optimizer}

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple.

Examples: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

Returns

A dict containing the configured optimizers and learning-rate schedulers to be used for training.

def on_save_checkpoint(self, checkpoint):
321    def on_save_checkpoint(self, checkpoint):
322        # pop the keys you are not interested by
323        sd = checkpoint["state_dict"]
324        names = list(sd.keys())
325        for name in names:
326            if any(x in name for x in ["_orig_mod", "fc", "teacher"]):
327                sd.pop(name)

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Args: checkpoint: The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Example::

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object

Note: Lightning saves all aspects of training (epoch, global step, etc...) including amp scaling. There is no need for you to store anything about training.