ptame.models.ptame_module

  1from collections.abc import Callable, Iterator
  2from functools import partial
  3from pathlib import Path
  4from typing import Any
  5
  6import torch
  7from lightning import LightningModule
  8from torch import nn
  9from torchmetrics import MaxMetric, MeanMetric, Metric
 10from torchmetrics.classification.accuracy import Accuracy
 11
 12from ptame.utils.measures import Composer
 13
 14from .components.loss import Loss
 15
 16
 17class PTAMELitModule(LightningModule):
 18    """`LightningModule` for PTAME.
 19
 20    A `LightningModule` implements 8 key methods:
 21
 22    ```python
 23    def __init__(self):
 24    # Define initialization code here.
 25
 26    def setup(self, stage):
 27    # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
 28    # This hook is called on every process when using DDP.
 29
 30    def training_step(self, batch, batch_idx):
 31    # The complete training step.
 32
 33    def validation_step(self, batch, batch_idx):
 34    # The complete validation step.
 35
 36    def test_step(self, batch, batch_idx):
 37    # The complete test step.
 38
 39    def predict_step(self, batch, batch_idx):
 40    # The complete predict step.
 41
 42    def configure_optimizers(self):
 43    # Define and configure optimizers and LR schedulers.
 44    ```
 45
 46    Docs:
 47        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
 48    """
 49
 50    def __init__(
 51        self,
 52        net: torch.nn.Module,
 53        loss: Loss,
 54        optimizer: Callable[
 55            [Iterator[torch.nn.Parameter]], torch.optim.Optimizer
 56        ],
 57        scheduler: dict[str, Any] | None,
 58        val_measures: dict[str, Metric] | None,
 59        test_measures: dict[str, Metric] | None,
 60        feature_contribution: bool = False,
 61        terminate_on_nan: bool = True,
 62        compile: bool = False,
 63        **kwargs,
 64    ) -> None:
 65        """Initialize a `PTAMELitModule`.
 66
 67        :param net: The model to train.
 68        :param optimizer: The optimizer to use for training.
 69        :param scheduler: The learning rate scheduler to use for training,
 70            together with the options (how often it should update etc.).
 71        :param val_measures: The explainability specific validation measures to
 72            use.
 73        :param test_measures: " test ".
 74        :param compile: Whether to compile the model before training.
 75        """
 76        super().__init__()
 77
 78        # this line allows to access init params with 'self.hparams' attribute
 79        # also ensures init params will be stored in ckpt
 80        self.save_hyperparameters(
 81            ignore=[
 82                "net",
 83                "loss",
 84                "val_measures",
 85                "test_measures",
 86                "optimizer",
 87                "scheduler",
 88            ],
 89            logger=False,
 90        )
 91
 92        self.net = net
 93
 94        # cut off initialization for simple restore
 95        if loss is None:
 96            return
 97
 98        # loss function
 99        self.criterion = loss
100
101        # optimizer and scheduler
102        self.optimizer_cfg = optimizer
103        self.scheduler_cfg = scheduler
104
105        # metric objects are initialized in `setup` method
106        self.train_acc = None
107        self.val_acc = None
108        self.test_acc = None
109        # for averaging loss across batches
110        self.train_losses = nn.ModuleList(
111            [MeanMetric() for _ in range(self.criterion.num_terms)]
112        )
113        self.val_losses = nn.ModuleList(
114            [MeanMetric() for _ in range(self.criterion.num_terms)]
115        )
116        self.test_loss = MeanMetric()
117        # metric object for calculating and averaging ADIC or ROAD or across batches
118        self.val_measures = (
119            Composer(nn.ModuleList(val_measures.values()), prefix="val/")
120            if val_measures
121            else None
122        )
123        self.test_measures = (
124            Composer(nn.ModuleList(test_measures.values()), prefix="test/")
125            if test_measures
126            else None
127        )
128
129        # for tracking best so far validation accuracy
130        self.val_acc_best = MaxMetric()
131
132    def forward(self, x: torch.Tensor) -> torch.Tensor:
133        """Perform a forward pass through the model `self.net`.
134
135        :param x: A tensor of images.
136        :return: A tensor of logits.
137        """
138        if self.hparams.compile:
139            return self.compiled_net(x)
140        return self.net(x)
141
142    def on_train_start(self) -> None:
143        """Lightning hook that is called when training begins."""
144        # by default lightning executes validation step sanity checks before training starts,
145        # so it's worth to make sure validation metrics don't store results from these checks
146        [val_loss.reset() for val_loss in self.val_losses]
147        self.val_acc.reset()
148        self.val_acc_best.reset()
149
150    def model_step(
151        self, batch: tuple[torch.Tensor, torch.Tensor]
152    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
153        """Perform a single model step on a batch of data.
154
155        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
156
157        :return: A tuple containing (in order):
158            - A tensor of losses.
159            - A tensor of predictions.
160            - A tensor of target labels.
161        """
162        x, y_ground = batch
163        out = self.forward(x)
164        losses = self.criterion(**out, epoch=self.trainer.current_epoch)
165        return losses, out
166
167    def training_step(
168        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
169    ) -> torch.Tensor:
170        """Perform a single training step on a batch of data from the training
171        set.
172
173        :param batch: A batch of data (a tuple) containing the input tensor of
174            images and target labels.
175        :param batch_idx: The index of the current batch.
176        :return: A tensor of losses between model predictions and targets.
177        """
178        losses, out = self.model_step(batch)
179        preds_masked = out["logits_masked"]
180        targets = out["targets"]
181        # update and log metrics
182        for i, metric in enumerate(self.train_losses):
183            metric(losses[i])  # compute metric
184            self.log(f"train/loss[{i}]", metric, prog_bar=True)
185        self.train_acc(preds_masked, targets)
186        self.log("train/acc", self.train_acc, prog_bar=True)
187
188        # return loss or backpropagation will fail
189        if self.hparams.terminate_on_nan and losses[0].isnan().any():
190            raise ValueError("NaN detected in loss!")
191        return losses[0]
192
193    def on_validation_epoch_start(self) -> None:
194        """Lightning hook that is called before a validation epoch begins."""
195        if self.val_measures:
196            if self.hparams.compile:
197                self.val_measures.register_net(self.compiled_net)
198            else:
199                self.val_measures.register_net(self.net)
200
201    def validation_step(
202        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
203    ) -> None:
204        """Perform a single validation step on a batch of data from the
205        validation set.
206
207        :param batch: A batch of data (a tuple) containing the input tensor of
208            images and target labels.
209        :param batch_idx: The index of the current batch.
210        """
211        losses, out = self.model_step(batch)
212        preds = out["logits"]
213        preds_masked = out["logits_masked"]
214        targets = out["targets"]
215        maps = out["masks"]
216
217        # update and log metrics
218        for i, metric in enumerate(self.val_losses):
219            metric.update(losses[i])
220            self.log(f"val/loss[{i}]", metric)
221        self.val_acc.update(preds_masked, targets)
222        self.log("val/acc", self.val_acc)
223        if self.val_measures:
224            self.val_measures.update(batch[0], preds, targets, maps)
225
226    def on_validation_epoch_end(self) -> None:
227        "Lightning hook that is called when a validation epoch ends."
228        acc = self.val_acc.compute()  # get current val acc
229        self.val_acc_best(acc)  # update best so far val acc
230        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
231        # otherwise metric would be reset by lightning after each epoch
232        self.log(
233            "val/acc_best",
234            self.val_acc_best.compute(),
235            sync_dist=True,
236            prog_bar=True,
237        )
238        if self.val_measures:
239            self.log_dict(self.val_measures.compute())
240            self.val_measures.reset()
241
242    def on_test_epoch_start(self) -> None:
243        """Lightning hook that is called before a test epoch begins."""
244        if self.test_measures:
245            if self.hparams.compile:
246                self.test_measures.register_net(self.compiled_net)
247            else:
248                self.test_measures.register_net(self.net)
249
250    def test_step(
251        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
252    ) -> None:
253        """Perform a single test step on a batch of data from the test set.
254
255        :param batch: A batch of data (a tuple) containing the input tensor of
256            images and target labels.
257        :param batch_idx: The index of the current batch.
258        """
259        losses, out = self.model_step(batch)
260        preds = out["logits"]
261        preds_masked = out["logits_masked"]
262        targets = out["targets"]
263        maps = out["masks"]
264
265        # update and log metrics
266        self.test_loss.update(losses[0])
267        self.log("test/loss", self.test_loss)
268        self.test_acc.update(preds_masked, targets)
269        self.log("test/acc", self.test_acc)
270        if save_masks := self.hparams.get("save_masks"):
271            Path(save_masks).mkdir(parents=True, exist_ok=True)
272            torch.save(maps, f"{save_masks}/{batch_idx}.pt")
273        if self.test_measures:
274            self.test_measures.update(batch[0], preds, targets, maps)
275
276    def on_test_epoch_end(self) -> None:
277        """Lightning hook that is called when a test epoch ends."""
278        if self.hparams.feature_contribution:
279            layer_names, contribs = self.net.attention.get_contributions()
280            contribs_dict = {
281                layer: contrib
282                for layer, contrib in zip(layer_names, contribs.mean(dim=0))
283            }
284            self.log_dict(contribs_dict)
285
286        if self.test_measures:
287            self.log_dict(self.test_measures.compute())
288            self.test_measures.reset()
289
290    def setup(self, stage: str) -> None:
291        """Lightning hook that is called at the beginning of fit (train +
292        validate), validate, test, or predict.
293
294        This is a good hook when you need to build models dynamically or adjust
295        something about them. This hook is called on every process when using
296        DDP.
297
298        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
299        """
300
301        if self.hparams.compile:
302            self.compiled_net = torch.compile(self.net)
303        # metric objects for calculating and averaging accuracy across batches
304        self.train_acc = Accuracy(
305            task="multiclass", num_classes=self.trainer.datamodule.num_classes
306        )
307        self.val_acc = Accuracy(
308            task="multiclass", num_classes=self.trainer.datamodule.num_classes
309        )
310        self.test_acc = Accuracy(
311            task="multiclass", num_classes=self.trainer.datamodule.num_classes
312        )
313
314    def teardown(self, stage: str) -> None:
315        """Lightning hook that is called at the end of fit (train + validate),
316        validate, test, or predict.
317
318        This is a good hook when you need to clean something up after the run.
319
320        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
321        """
322        if self.hparams.compile:
323            del self.compiled_net
324
325    def configure_optimizers(self) -> dict[str, Any]:
326        """Choose what optimizers and learning-rate schedulers to use in your
327        optimization. Normally you'd need one. But in the case of GANs or
328        similar you might have multiple.
329
330        Examples:
331            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
332
333        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
334        """
335        optimizer = self.optimizer_cfg(params=self.trainer.model.parameters())
336        if (scheduler_info := self.scheduler_cfg) is not None:
337            scheduler = scheduler_info["scheduler"]
338            if scheduler_info["name"] == "OneCycleLR":
339                scheduler = partial(
340                    scheduler,
341                    total_steps=self.trainer.estimated_stepping_batches,
342                )
343            scheduler = scheduler(optimizer=optimizer)
344            scheduler_params = scheduler_info["params"]
345            return {
346                "optimizer": optimizer,
347                "lr_scheduler": {
348                    "scheduler": scheduler,
349                    "interval": scheduler_params["interval"],
350                    "frequency": scheduler_params["frequency"],
351                    "name": scheduler_params["name"],
352                },
353            }
354        return {"optimizer": optimizer}
355
356    def on_save_checkpoint(self, checkpoint):
357        # pop the keys you are not interested by
358        sd = checkpoint["state_dict"]
359        names = list(sd.keys())
360        for name in names:
361            if "backbone" in name or "_orig_mod" in name:
362                sd.pop(name)
363            if "net.attention.attention" in name:
364                sd.pop(name)
365
366
if __name__ == "__main__":
    # Smoke check: with loss=None the constructor takes its early "restore"
    # exit, so only super().__init__, save_hyperparameters and `net` run.
    PTAMELitModule(
        net=None,
        loss=None,
        optimizer=None,
        scheduler=None,
        val_measures=None,
        test_measures=None,
    )
class PTAMELitModule(lightning.pytorch.core.module.LightningModule):
 18class PTAMELitModule(LightningModule):
 19    """`LightningModule` for PTAME.
 20
 21    A `LightningModule` implements 8 key methods:
 22
 23    ```python
 24    def __init__(self):
 25    # Define initialization code here.
 26
 27    def setup(self, stage):
 28    # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
 29    # This hook is called on every process when using DDP.
 30
 31    def training_step(self, batch, batch_idx):
 32    # The complete training step.
 33
 34    def validation_step(self, batch, batch_idx):
 35    # The complete validation step.
 36
 37    def test_step(self, batch, batch_idx):
 38    # The complete test step.
 39
 40    def predict_step(self, batch, batch_idx):
 41    # The complete predict step.
 42
 43    def configure_optimizers(self):
 44    # Define and configure optimizers and LR schedulers.
 45    ```
 46
 47    Docs:
 48        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
 49    """
 50
 51    def __init__(
 52        self,
 53        net: torch.nn.Module,
 54        loss: Loss,
 55        optimizer: Callable[
 56            [Iterator[torch.nn.Parameter]], torch.optim.Optimizer
 57        ],
 58        scheduler: dict[str, Any] | None,
 59        val_measures: dict[str, Metric] | None,
 60        test_measures: dict[str, Metric] | None,
 61        feature_contribution: bool = False,
 62        terminate_on_nan: bool = True,
 63        compile: bool = False,
 64        **kwargs,
 65    ) -> None:
 66        """Initialize a `PTAMELitModule`.
 67
 68        :param net: The model to train.
 69        :param optimizer: The optimizer to use for training.
 70        :param scheduler: The learning rate scheduler to use for training,
 71            together with the options (how often it should update etc.).
 72        :param val_measures: The explainability specific validation measures to
 73            use.
 74        :param test_measures: " test ".
 75        :param compile: Whether to compile the model before training.
 76        """
 77        super().__init__()
 78
 79        # this line allows to access init params with 'self.hparams' attribute
 80        # also ensures init params will be stored in ckpt
 81        self.save_hyperparameters(
 82            ignore=[
 83                "net",
 84                "loss",
 85                "val_measures",
 86                "test_measures",
 87                "optimizer",
 88                "scheduler",
 89            ],
 90            logger=False,
 91        )
 92
 93        self.net = net
 94
 95        # cut off initialization for simple restore
 96        if loss is None:
 97            return
 98
 99        # loss function
100        self.criterion = loss
101
102        # optimizer and scheduler
103        self.optimizer_cfg = optimizer
104        self.scheduler_cfg = scheduler
105
106        # metric objects are initialized in `setup` method
107        self.train_acc = None
108        self.val_acc = None
109        self.test_acc = None
110        # for averaging loss across batches
111        self.train_losses = nn.ModuleList(
112            [MeanMetric() for _ in range(self.criterion.num_terms)]
113        )
114        self.val_losses = nn.ModuleList(
115            [MeanMetric() for _ in range(self.criterion.num_terms)]
116        )
117        self.test_loss = MeanMetric()
118        # metric object for calculating and averaging ADIC or ROAD or across batches
119        self.val_measures = (
120            Composer(nn.ModuleList(val_measures.values()), prefix="val/")
121            if val_measures
122            else None
123        )
124        self.test_measures = (
125            Composer(nn.ModuleList(test_measures.values()), prefix="test/")
126            if test_measures
127            else None
128        )
129
130        # for tracking best so far validation accuracy
131        self.val_acc_best = MaxMetric()
132
133    def forward(self, x: torch.Tensor) -> torch.Tensor:
134        """Perform a forward pass through the model `self.net`.
135
136        :param x: A tensor of images.
137        :return: A tensor of logits.
138        """
139        if self.hparams.compile:
140            return self.compiled_net(x)
141        return self.net(x)
142
143    def on_train_start(self) -> None:
144        """Lightning hook that is called when training begins."""
145        # by default lightning executes validation step sanity checks before training starts,
146        # so it's worth to make sure validation metrics don't store results from these checks
147        [val_loss.reset() for val_loss in self.val_losses]
148        self.val_acc.reset()
149        self.val_acc_best.reset()
150
151    def model_step(
152        self, batch: tuple[torch.Tensor, torch.Tensor]
153    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
154        """Perform a single model step on a batch of data.
155
156        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
157
158        :return: A tuple containing (in order):
159            - A tensor of losses.
160            - A tensor of predictions.
161            - A tensor of target labels.
162        """
163        x, y_ground = batch
164        out = self.forward(x)
165        losses = self.criterion(**out, epoch=self.trainer.current_epoch)
166        return losses, out
167
168    def training_step(
169        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
170    ) -> torch.Tensor:
171        """Perform a single training step on a batch of data from the training
172        set.
173
174        :param batch: A batch of data (a tuple) containing the input tensor of
175            images and target labels.
176        :param batch_idx: The index of the current batch.
177        :return: A tensor of losses between model predictions and targets.
178        """
179        losses, out = self.model_step(batch)
180        preds_masked = out["logits_masked"]
181        targets = out["targets"]
182        # update and log metrics
183        for i, metric in enumerate(self.train_losses):
184            metric(losses[i])  # compute metric
185            self.log(f"train/loss[{i}]", metric, prog_bar=True)
186        self.train_acc(preds_masked, targets)
187        self.log("train/acc", self.train_acc, prog_bar=True)
188
189        # return loss or backpropagation will fail
190        if self.hparams.terminate_on_nan and losses[0].isnan().any():
191            raise ValueError("NaN detected in loss!")
192        return losses[0]
193
194    def on_validation_epoch_start(self) -> None:
195        """Lightning hook that is called before a validation epoch begins."""
196        if self.val_measures:
197            if self.hparams.compile:
198                self.val_measures.register_net(self.compiled_net)
199            else:
200                self.val_measures.register_net(self.net)
201
202    def validation_step(
203        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
204    ) -> None:
205        """Perform a single validation step on a batch of data from the
206        validation set.
207
208        :param batch: A batch of data (a tuple) containing the input tensor of
209            images and target labels.
210        :param batch_idx: The index of the current batch.
211        """
212        losses, out = self.model_step(batch)
213        preds = out["logits"]
214        preds_masked = out["logits_masked"]
215        targets = out["targets"]
216        maps = out["masks"]
217
218        # update and log metrics
219        for i, metric in enumerate(self.val_losses):
220            metric.update(losses[i])
221            self.log(f"val/loss[{i}]", metric)
222        self.val_acc.update(preds_masked, targets)
223        self.log("val/acc", self.val_acc)
224        if self.val_measures:
225            self.val_measures.update(batch[0], preds, targets, maps)
226
227    def on_validation_epoch_end(self) -> None:
228        "Lightning hook that is called when a validation epoch ends."
229        acc = self.val_acc.compute()  # get current val acc
230        self.val_acc_best(acc)  # update best so far val acc
231        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
232        # otherwise metric would be reset by lightning after each epoch
233        self.log(
234            "val/acc_best",
235            self.val_acc_best.compute(),
236            sync_dist=True,
237            prog_bar=True,
238        )
239        if self.val_measures:
240            self.log_dict(self.val_measures.compute())
241            self.val_measures.reset()
242
243    def on_test_epoch_start(self) -> None:
244        """Lightning hook that is called before a test epoch begins."""
245        if self.test_measures:
246            if self.hparams.compile:
247                self.test_measures.register_net(self.compiled_net)
248            else:
249                self.test_measures.register_net(self.net)
250
251    def test_step(
252        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
253    ) -> None:
254        """Perform a single test step on a batch of data from the test set.
255
256        :param batch: A batch of data (a tuple) containing the input tensor of
257            images and target labels.
258        :param batch_idx: The index of the current batch.
259        """
260        losses, out = self.model_step(batch)
261        preds = out["logits"]
262        preds_masked = out["logits_masked"]
263        targets = out["targets"]
264        maps = out["masks"]
265
266        # update and log metrics
267        self.test_loss.update(losses[0])
268        self.log("test/loss", self.test_loss)
269        self.test_acc.update(preds_masked, targets)
270        self.log("test/acc", self.test_acc)
271        if save_masks := self.hparams.get("save_masks"):
272            Path(save_masks).mkdir(parents=True, exist_ok=True)
273            torch.save(maps, f"{save_masks}/{batch_idx}.pt")
274        if self.test_measures:
275            self.test_measures.update(batch[0], preds, targets, maps)
276
277    def on_test_epoch_end(self) -> None:
278        """Lightning hook that is called when a test epoch ends."""
279        if self.hparams.feature_contribution:
280            layer_names, contribs = self.net.attention.get_contributions()
281            contribs_dict = {
282                layer: contrib
283                for layer, contrib in zip(layer_names, contribs.mean(dim=0))
284            }
285            self.log_dict(contribs_dict)
286
287        if self.test_measures:
288            self.log_dict(self.test_measures.compute())
289            self.test_measures.reset()
290
291    def setup(self, stage: str) -> None:
292        """Lightning hook that is called at the beginning of fit (train +
293        validate), validate, test, or predict.
294
295        This is a good hook when you need to build models dynamically or adjust
296        something about them. This hook is called on every process when using
297        DDP.
298
299        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
300        """
301
302        if self.hparams.compile:
303            self.compiled_net = torch.compile(self.net)
304        # metric objects for calculating and averaging accuracy across batches
305        self.train_acc = Accuracy(
306            task="multiclass", num_classes=self.trainer.datamodule.num_classes
307        )
308        self.val_acc = Accuracy(
309            task="multiclass", num_classes=self.trainer.datamodule.num_classes
310        )
311        self.test_acc = Accuracy(
312            task="multiclass", num_classes=self.trainer.datamodule.num_classes
313        )
314
315    def teardown(self, stage: str) -> None:
316        """Lightning hook that is called at the end of fit (train + validate),
317        validate, test, or predict.
318
319        This is a good hook when you need to clean something up after the run.
320
321        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
322        """
323        if self.hparams.compile:
324            del self.compiled_net
325
326    def configure_optimizers(self) -> dict[str, Any]:
327        """Choose what optimizers and learning-rate schedulers to use in your
328        optimization. Normally you'd need one. But in the case of GANs or
329        similar you might have multiple.
330
331        Examples:
332            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
333
334        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
335        """
336        optimizer = self.optimizer_cfg(params=self.trainer.model.parameters())
337        if (scheduler_info := self.scheduler_cfg) is not None:
338            scheduler = scheduler_info["scheduler"]
339            if scheduler_info["name"] == "OneCycleLR":
340                scheduler = partial(
341                    scheduler,
342                    total_steps=self.trainer.estimated_stepping_batches,
343                )
344            scheduler = scheduler(optimizer=optimizer)
345            scheduler_params = scheduler_info["params"]
346            return {
347                "optimizer": optimizer,
348                "lr_scheduler": {
349                    "scheduler": scheduler,
350                    "interval": scheduler_params["interval"],
351                    "frequency": scheduler_params["frequency"],
352                    "name": scheduler_params["name"],
353                },
354            }
355        return {"optimizer": optimizer}
356
357    def on_save_checkpoint(self, checkpoint):
358        # pop the keys you are not interested by
359        sd = checkpoint["state_dict"]
360        names = list(sd.keys())
361        for name in names:
362            if "backbone" in name or "_orig_mod" in name:
363                sd.pop(name)
364            if "net.attention.attention" in name:
365                sd.pop(name)

LightningModule for PTAME.

A LightningModule implements the following key methods:

def __init__(self):
# Define initialization code here.

def setup(self, stage):
# Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
# This hook is called on every process when using DDP.

def training_step(self, batch, batch_idx):
# The complete training step.

def validation_step(self, batch, batch_idx):
# The complete validation step.

def test_step(self, batch, batch_idx):
# The complete test step.

def predict_step(self, batch, batch_idx):
# The complete predict step.

def configure_optimizers(self):
# Define and configure optimizers and LR schedulers.

Docs: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html

PTAMELitModule( net: torch.nn.modules.module.Module, loss: ptame.models.components.loss.Loss, optimizer: Callable[[Iterator[torch.nn.parameter.Parameter]], torch.optim.optimizer.Optimizer], scheduler: dict[str, typing.Any] | None, val_measures: dict[str, torchmetrics.metric.Metric] | None, test_measures: dict[str, torchmetrics.metric.Metric] | None, feature_contribution: bool = False, terminate_on_nan: bool = True, compile: bool = False, **kwargs)
 51    def __init__(
 52        self,
 53        net: torch.nn.Module,
 54        loss: Loss,
 55        optimizer: Callable[
 56            [Iterator[torch.nn.Parameter]], torch.optim.Optimizer
 57        ],
 58        scheduler: dict[str, Any] | None,
 59        val_measures: dict[str, Metric] | None,
 60        test_measures: dict[str, Metric] | None,
 61        feature_contribution: bool = False,
 62        terminate_on_nan: bool = True,
 63        compile: bool = False,
 64        **kwargs,
 65    ) -> None:
 66        """Initialize a `PTAMELitModule`.
 67
 68        :param net: The model to train.
 69        :param optimizer: The optimizer to use for training.
 70        :param scheduler: The learning rate scheduler to use for training,
 71            together with the options (how often it should update etc.).
 72        :param val_measures: The explainability specific validation measures to
 73            use.
 74        :param test_measures: " test ".
 75        :param compile: Whether to compile the model before training.
 76        """
 77        super().__init__()
 78
 79        # this line allows to access init params with 'self.hparams' attribute
 80        # also ensures init params will be stored in ckpt
 81        self.save_hyperparameters(
 82            ignore=[
 83                "net",
 84                "loss",
 85                "val_measures",
 86                "test_measures",
 87                "optimizer",
 88                "scheduler",
 89            ],
 90            logger=False,
 91        )
 92
 93        self.net = net
 94
 95        # cut off initialization for simple restore
 96        if loss is None:
 97            return
 98
 99        # loss function
100        self.criterion = loss
101
102        # optimizer and scheduler
103        self.optimizer_cfg = optimizer
104        self.scheduler_cfg = scheduler
105
106        # metric objects are initialized in `setup` method
107        self.train_acc = None
108        self.val_acc = None
109        self.test_acc = None
110        # for averaging loss across batches
111        self.train_losses = nn.ModuleList(
112            [MeanMetric() for _ in range(self.criterion.num_terms)]
113        )
114        self.val_losses = nn.ModuleList(
115            [MeanMetric() for _ in range(self.criterion.num_terms)]
116        )
117        self.test_loss = MeanMetric()
118        # metric object for calculating and averaging ADIC or ROAD or across batches
119        self.val_measures = (
120            Composer(nn.ModuleList(val_measures.values()), prefix="val/")
121            if val_measures
122            else None
123        )
124        self.test_measures = (
125            Composer(nn.ModuleList(test_measures.values()), prefix="test/")
126            if test_measures
127            else None
128        )
129
130        # for tracking best so far validation accuracy
131        self.val_acc_best = MaxMetric()

Initialize a PTAMELitModule.

Parameters
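  • net: The model to train.
  • loss: The loss module; it is called with the network outputs and the current epoch. Passing None skips the rest of the initialisation (simple restore).
  • optimizer: The optimizer to use for training.
  • scheduler: The learning rate scheduler to use for training, together with the options (how often it should update etc.).
  • val_measures: The explainability-specific validation measures to use.
  • test_measures: The explainability-specific test measures to use.
  • feature_contribution: Whether to log the per-layer contributions reported by net.attention at the end of testing.
  • terminate_on_nan: Whether to raise an error when the training loss contains NaN.
  • compile: Whether to compile the model before training.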
net
criterion
optimizer_cfg
scheduler_cfg
train_acc
val_acc
test_acc
train_losses
val_losses
test_loss
val_measures
test_measures
val_acc_best
def forward(self, x: torch.Tensor) -> torch.Tensor:
133    def forward(self, x: torch.Tensor) -> torch.Tensor:
134        """Perform a forward pass through the model `self.net`.
135
136        :param x: A tensor of images.
137        :return: A tensor of logits.
138        """
139        if self.hparams.compile:
140            return self.compiled_net(x)
141        return self.net(x)

Perform a forward pass through the model self.net.

Parameters
  • x: A tensor of images.
Returns

A tensor of logits.

def on_train_start(self) -> None:
143    def on_train_start(self) -> None:
144        """Lightning hook that is called when training begins."""
145        # by default lightning executes validation step sanity checks before training starts,
146        # so it's worth to make sure validation metrics don't store results from these checks
147        [val_loss.reset() for val_loss in self.val_losses]
148        self.val_acc.reset()
149        self.val_acc_best.reset()

Lightning hook that is called when training begins.

def model_step( self, batch: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
151    def model_step(
152        self, batch: tuple[torch.Tensor, torch.Tensor]
153    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
154        """Perform a single model step on a batch of data.
155
156        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
157
158        :return: A tuple containing (in order):
159            - A tensor of losses.
160            - A tensor of predictions.
161            - A tensor of target labels.
162        """
163        x, y_ground = batch
164        out = self.forward(x)
165        losses = self.criterion(**out, epoch=self.trainer.current_epoch)
166        return losses, out

Perform a single model step on a batch of data.

Parameters
  • batch: A batch of data (a tuple) containing the input tensor of images and target labels.
Returns

A tuple containing (in order): - A tensor of losses. - A tensor of predictions. - A tensor of target labels.

def training_step( self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
168    def training_step(
169        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
170    ) -> torch.Tensor:
171        """Perform a single training step on a batch of data from the training
172        set.
173
174        :param batch: A batch of data (a tuple) containing the input tensor of
175            images and target labels.
176        :param batch_idx: The index of the current batch.
177        :return: A tensor of losses between model predictions and targets.
178        """
179        losses, out = self.model_step(batch)
180        preds_masked = out["logits_masked"]
181        targets = out["targets"]
182        # update and log metrics
183        for i, metric in enumerate(self.train_losses):
184            metric(losses[i])  # compute metric
185            self.log(f"train/loss[{i}]", metric, prog_bar=True)
186        self.train_acc(preds_masked, targets)
187        self.log("train/acc", self.train_acc, prog_bar=True)
188
189        # return loss or backpropagation will fail
190        if self.hparams.terminate_on_nan and losses[0].isnan().any():
191            raise ValueError("NaN detected in loss!")
192        return losses[0]

Perform a single training step on a batch of data from the training set.

Parameters
  • batch: A batch of data (a tuple) containing the input tensor of images and target labels.
  • batch_idx: The index of the current batch.
Returns

A tensor of losses between model predictions and targets.

def on_validation_epoch_start(self) -> None:
194    def on_validation_epoch_start(self) -> None:
195        """Lightning hook that is called before a validation epoch begins."""
196        if self.val_measures:
197            if self.hparams.compile:
198                self.val_measures.register_net(self.compiled_net)
199            else:
200                self.val_measures.register_net(self.net)

Lightning hook that is called before a validation epoch begins.

def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
202    def validation_step(
203        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
204    ) -> None:
205        """Perform a single validation step on a batch of data from the
206        validation set.
207
208        :param batch: A batch of data (a tuple) containing the input tensor of
209            images and target labels.
210        :param batch_idx: The index of the current batch.
211        """
212        losses, out = self.model_step(batch)
213        preds = out["logits"]
214        preds_masked = out["logits_masked"]
215        targets = out["targets"]
216        maps = out["masks"]
217
218        # update and log metrics
219        for i, metric in enumerate(self.val_losses):
220            metric.update(losses[i])
221            self.log(f"val/loss[{i}]", metric)
222        self.val_acc.update(preds_masked, targets)
223        self.log("val/acc", self.val_acc)
224        if self.val_measures:
225            self.val_measures.update(batch[0], preds, targets, maps)

Perform a single validation step on a batch of data from the validation set.

Parameters
  • batch: A batch of data (a tuple) containing the input tensor of images and target labels.
  • batch_idx: The index of the current batch.
def on_validation_epoch_end(self) -> None:
227    def on_validation_epoch_end(self) -> None:
228        "Lightning hook that is called when a validation epoch ends."
229        acc = self.val_acc.compute()  # get current val acc
230        self.val_acc_best(acc)  # update best so far val acc
231        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
232        # otherwise metric would be reset by lightning after each epoch
233        self.log(
234            "val/acc_best",
235            self.val_acc_best.compute(),
236            sync_dist=True,
237            prog_bar=True,
238        )
239        if self.val_measures:
240            self.log_dict(self.val_measures.compute())
241            self.val_measures.reset()

Lightning hook that is called when a validation epoch ends.

def on_test_epoch_start(self) -> None:
243    def on_test_epoch_start(self) -> None:
244        """Lightning hook that is called before a test epoch begins."""
245        if self.test_measures:
246            if self.hparams.compile:
247                self.test_measures.register_net(self.compiled_net)
248            else:
249                self.test_measures.register_net(self.net)

Lightning hook that is called before a test epoch begins.

def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
251    def test_step(
252        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
253    ) -> None:
254        """Perform a single test step on a batch of data from the test set.
255
256        :param batch: A batch of data (a tuple) containing the input tensor of
257            images and target labels.
258        :param batch_idx: The index of the current batch.
259        """
260        losses, out = self.model_step(batch)
261        preds = out["logits"]
262        preds_masked = out["logits_masked"]
263        targets = out["targets"]
264        maps = out["masks"]
265
266        # update and log metrics
267        self.test_loss.update(losses[0])
268        self.log("test/loss", self.test_loss)
269        self.test_acc.update(preds_masked, targets)
270        self.log("test/acc", self.test_acc)
271        if save_masks := self.hparams.get("save_masks"):
272            Path(save_masks).mkdir(parents=True, exist_ok=True)
273            torch.save(maps, f"{save_masks}/{batch_idx}.pt")
274        if self.test_measures:
275            self.test_measures.update(batch[0], preds, targets, maps)

Perform a single test step on a batch of data from the test set.

Parameters
  • batch: A batch of data (a tuple) containing the input tensor of images and target labels.
  • batch_idx: The index of the current batch.
def on_test_epoch_end(self) -> None:
277    def on_test_epoch_end(self) -> None:
278        """Lightning hook that is called when a test epoch ends."""
279        if self.hparams.feature_contribution:
280            layer_names, contribs = self.net.attention.get_contributions()
281            contribs_dict = {
282                layer: contrib
283                for layer, contrib in zip(layer_names, contribs.mean(dim=0))
284            }
285            self.log_dict(contribs_dict)
286
287        if self.test_measures:
288            self.log_dict(self.test_measures.compute())
289            self.test_measures.reset()

Lightning hook that is called when a test epoch ends.

def setup(self, stage: str) -> None:
291    def setup(self, stage: str) -> None:
292        """Lightning hook that is called at the beginning of fit (train +
293        validate), validate, test, or predict.
294
295        This is a good hook when you need to build models dynamically or adjust
296        something about them. This hook is called on every process when using
297        DDP.
298
299        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
300        """
301
302        if self.hparams.compile:
303            self.compiled_net = torch.compile(self.net)
304        # metric objects for calculating and averaging accuracy across batches
305        self.train_acc = Accuracy(
306            task="multiclass", num_classes=self.trainer.datamodule.num_classes
307        )
308        self.val_acc = Accuracy(
309            task="multiclass", num_classes=self.trainer.datamodule.num_classes
310        )
311        self.test_acc = Accuracy(
312            task="multiclass", num_classes=self.trainer.datamodule.num_classes
313        )

Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.

This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters
  • stage: Either "fit", "validate", "test", or "predict".
def teardown(self, stage: str) -> None:
315    def teardown(self, stage: str) -> None:
316        """Lightning hook that is called at the end of fit (train + validate),
317        validate, test, or predict.
318
319        This is a good hook when you need to clean something up after the run.
320
321        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
322        """
323        if self.hparams.compile:
324            del self.compiled_net

Lightning hook that is called at the end of fit (train + validate), validate, test, or predict.

This is a good hook when you need to clean something up after the run.

Parameters
  • stage: Either "fit", "validate", "test", or "predict".
def configure_optimizers(self) -> dict[str, typing.Any]:
326    def configure_optimizers(self) -> dict[str, Any]:
327        """Choose what optimizers and learning-rate schedulers to use in your
328        optimization. Normally you'd need one. But in the case of GANs or
329        similar you might have multiple.
330
331        Examples:
332            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
333
334        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
335        """
336        optimizer = self.optimizer_cfg(params=self.trainer.model.parameters())
337        if (scheduler_info := self.scheduler_cfg) is not None:
338            scheduler = scheduler_info["scheduler"]
339            if scheduler_info["name"] == "OneCycleLR":
340                scheduler = partial(
341                    scheduler,
342                    total_steps=self.trainer.estimated_stepping_batches,
343                )
344            scheduler = scheduler(optimizer=optimizer)
345            scheduler_params = scheduler_info["params"]
346            return {
347                "optimizer": optimizer,
348                "lr_scheduler": {
349                    "scheduler": scheduler,
350                    "interval": scheduler_params["interval"],
351                    "frequency": scheduler_params["frequency"],
352                    "name": scheduler_params["name"],
353                },
354            }
355        return {"optimizer": optimizer}

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple.

Examples: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

Returns

A dict containing the configured optimizers and learning-rate schedulers to be used for training.

def on_save_checkpoint(self, checkpoint):
357    def on_save_checkpoint(self, checkpoint):
358        # pop the keys you are not interested by
359        sd = checkpoint["state_dict"]
360        names = list(sd.keys())
361        for name in names:
362            if "backbone" in name or "_orig_mod" in name:
363                sd.pop(name)
364            if "net.attention.attention" in name:
365                sd.pop(name)

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Args: checkpoint: The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Example::

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object

Note: Lightning saves all aspects of training (epoch, global step, etc...) including amp scaling. There is no need for you to store anything about training.