ptame.train

  1from typing import Any, Dict, List, Optional, Tuple
  2
  3import hydra
  4import lightning as L
  5import torch
  6from lightning import Callback, LightningDataModule, LightningModule, Trainer
  7from lightning.pytorch.loggers import Logger
  8from omegaconf import DictConfig
  9
 10from ptame.utils import (  # noqa: E402
 11    RankedLogger,
 12    extras,
 13    get_metric_value,
 14    instantiate_callbacks,
 15    instantiate_loggers,
 16    log_hyperparameters,
 17    task_wrapper,
 18)
 19
# Module-level logger. ``rank_zero_only=True`` presumably restricts output to
# the rank-0 process in multi-process runs -- confirm in ptame.utils.RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)
 21
 22
 23@task_wrapper
 24def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 25    """Trains the model. Can additionally evaluate on a testset, using best
 26    weights obtained during training.
 27
 28    This method is wrapped in optional @task_wrapper decorator, that controls
 29    the behavior during failure. Useful for multiruns, saving info about the
 30    crash, etc.
 31
 32    :param cfg: A DictConfig configuration composed by Hydra.
 33    :return: A tuple with metrics and dict with all instantiated objects.
 34    """
 35    # set seed for random number generators in pytorch, numpy and python.random
 36    if cfg.get("seed"):
 37        L.seed_everything(cfg.seed, workers=True)
 38
 39    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
 40    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
 41    log.info(f"Instantiating model <{cfg.model._target_}>")
 42    model: LightningModule = hydra.utils.instantiate(cfg.model)
 43
 44    if ckpt_path := cfg.get("ckpt_path"):
 45        log.info(f"Loading model from checkpoint path {ckpt_path}")
 46        incompatible_keys = model.load_state_dict(
 47            torch.load(ckpt_path, weights_only=True)["state_dict"],
 48            strict=False,
 49        )
 50        # get compatible keys
 51        concrete_keys = model.state_dict().keys()
 52        log.info(
 53            f"Compatible keys: {set(concrete_keys) - set(incompatible_keys[0])}"
 54        )
 55        log.info(f"Unexpected keys: {incompatible_keys[1]}")
 56
 57    if aux_path := cfg.get("aux_ckpt_path"):
 58        log.info(f"Loading auxiliary model from checkpoint path {aux_path}")
 59        aux_ckpt = torch.load(aux_path, weights_only=True)["state_dict"]
 60        aux_ckpt = {
 61            k.replace("net.", ""): v
 62            for k, v in aux_ckpt.items()
 63            if "fc" not in k
 64        }
 65        incompatible_keys = model.net.attention.attention.load_state_dict(
 66            aux_ckpt,
 67            strict=False,
 68        )
 69        # check if the model is in eval mode and has no gradients
 70        log.info(
 71            f"Aux. model in eval mode: {not model.net.attention.attention.training}"
 72        )
 73        log.info(
 74            f"Aux. model has no gradients: {not any(p.requires_grad for p in model.net.attention.attention.parameters())}"
 75        )
 76        # get compatible keys
 77        concrete_keys = list(
 78            set(model.net.attention.attention.state_dict().keys())
 79            - set(incompatible_keys[0])
 80        )
 81        display_keys = (
 82            (
 83                "Compatible keys: "
 84                + " ".join(concrete_keys[:10])
 85                + f"... {len(concrete_keys) - 10} more items remaining"
 86                if len(concrete_keys) > 10
 87                else concrete_keys
 88            )
 89            if len(incompatible_keys[0]) > 10
 90            else f"Incompatible keys: {incompatible_keys[0]}"
 91        )
 92        log.info(display_keys)
 93        log.info(f"Unexpected keys: {incompatible_keys[1]}")
 94
 95    log.info("Instantiating callbacks...")
 96    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))
 97
 98    log.info("Instantiating loggers...")
 99    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
100
101    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
102    trainer: Trainer = hydra.utils.instantiate(
103        cfg.trainer, callbacks=callbacks, logger=logger
104    )
105    hydra.utils.instantiate(cfg.matmul)
106
107    object_dict = {
108        "cfg": cfg,
109        "datamodule": datamodule,
110        "model": model,
111        "callbacks": callbacks,
112        "logger": logger,
113        "trainer": trainer,
114    }
115
116    if logger:
117        log.info("Logging hyperparameters!")
118        trainable = log_hyperparameters(object_dict)
119        if not trainable:
120            log.warning("Exiting training due to non-trainable model.")
121            return {}, object_dict
122
123    if cfg.get("train"):
124        log.info("Starting training!")
125        trainer.fit(model=model, datamodule=datamodule)
126
127    train_metrics = trainer.callback_metrics
128
129    if cfg.get("test"):
130        log.info("Starting testing!")
131        trainer.test(model=model, datamodule=datamodule)
132
133    test_metrics = trainer.callback_metrics
134
135    # merge train and test metrics
136    metric_dict = {**train_metrics, **test_metrics}
137
138    return metric_dict, object_dict
139
140
@hydra.main(
    version_base="1.3", config_path="configs", config_name="train.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: prepare the run, train, and report the objective.

    :param cfg: Configuration composed by Hydra from ``configs/train.yaml``.
    :return: The value of the metric named by ``cfg.optimized_metric``, as
        extracted by ``get_metric_value`` for Hydra-based hyperparameter
        optimization; may be ``None``.
    """
    # pre-run utilities (e.g. prompting for tags, printing the config tree)
    extras(cfg)

    results, _objects = train(cfg)

    # hand the objective back to Hydra so sweepers can optimize it
    return get_metric_value(
        metric_dict=results, metric_name=cfg.get("optimized_metric")
    )


if __name__ == "__main__":
    main()
log = <RankedLogger ptame.train (INFO)>
def train(cfg: omegaconf.dictconfig.DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 73    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 74        # execute the task
 75        exit_code = 1
 76        try:
 77            metric_dict, object_dict = task_func(cfg=cfg)
 78            exit_code = 0
 79        # things to do if exception occurs
 80        except Exception as ex:
 81            # save exception to `.log` file
 82            log.exception("")
 83            # some hyperparameter combinations might be invalid or cause out-of-memory errors
 84            # so when using hparam search plugins like Optuna, you might want to disable
 85            # raising the below exception to avoid multirun failure
 86            raise ex
 87
 88        # things to always do after either success or exception
 89        finally:
 90            # display output dir path in terminal
 91            log.info(f"Output dir: {cfg.paths.output_dir}")
 92
 93            # always close wandb run (even if exception occurs so multirun won't fail)
 94            if find_spec("wandb"):  # check if wandb is installed
 95                import wandb
 96
 97                if wandb.run:
 98                    log.info("Closing wandb!")
 99                    wandb.finish(exit_code=exit_code)
100
101        return metric_dict, object_dict

Trains the model. Can additionally evaluate on a test set, using the weights the model holds at the end of training (the best checkpoint is not reloaded).

This function is wrapped in the optional @task_wrapper decorator, which controls the behavior on failure. Useful for multiruns, saving info about the crash, etc.

Parameters
  • cfg: A DictConfig configuration composed by Hydra.
Returns

A tuple with metrics and dict with all instantiated objects.

@hydra.main(version_base='1.3', config_path='configs', config_name='train.yaml')
def main(cfg: omegaconf.dictconfig.DictConfig) -> Optional[float]:
@hydra.main(
    version_base="1.3", config_path="configs", config_name="train.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: prepare the run, train, and report the objective.

    :param cfg: Configuration composed by Hydra from ``configs/train.yaml``.
    :return: The value of the metric named by ``cfg.optimized_metric``, as
        extracted by ``get_metric_value`` for Hydra-based hyperparameter
        optimization; may be ``None``.
    """
    # pre-run utilities (e.g. prompting for tags, printing the config tree)
    extras(cfg)

    results, _objects = train(cfg)

    # hand the objective back to Hydra so sweepers can optimize it
    return get_metric_value(
        metric_dict=results, metric_name=cfg.get("optimized_metric")
    )

Main entry point for training.

Parameters
  • cfg: DictConfig configuration composed by Hydra.
Returns

Optional[float] with optimized metric value.