ptame.train
from typing import Any, Dict, List, Optional, Tuple

import hydra
import lightning as L
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from ptame.utils import (  # noqa: E402
    RankedLogger,
    extras,
    get_metric_value,
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)


def _summarize_aux_keys(
    loaded_keys: List[str], missing_keys: List[str], limit: int = 10
) -> str:
    """Build the one-line log summary for an auxiliary checkpoint load.

    :param loaded_keys: State-dict keys that were found in the checkpoint.
    :param missing_keys: State-dict keys that the checkpoint did not provide.
    :param limit: Maximum number of keys to print before truncating.
    :return: ``"Incompatible keys: ..."`` when at most ``limit`` keys are
        missing, otherwise a (possibly truncated) ``"Compatible keys: ..."``
        line.
    """
    # Few missing keys: listing them is the most informative summary.
    if len(missing_keys) <= limit:
        return f"Incompatible keys: {missing_keys}"

    # Many missing keys: list what *did* load instead, always with a label.
    summary = "Compatible keys: " + " ".join(loaded_keys[:limit])
    if len(loaded_keys) > limit:
        summary += f"... {len(loaded_keys) - limit} more items remaining"
    return summary


@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a test set.

    Testing is run on the model instance as it is after training; no
    ``ckpt_path`` is passed to ``trainer.test``, so no checkpoint is reloaded
    for the test run.

    This function is wrapped in the @task_wrapper decorator, which controls
    the behavior during failure. Useful for multiruns, saving info about the
    crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and dict with all instantiated objects.
        The metrics dict is empty when ``log_hyperparameters`` reports the
        model as non-trainable.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    # (compare against None so that an explicit `seed: 0` is honoured)
    if cfg.get("seed") is not None:
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    if ckpt_path := cfg.get("ckpt_path"):
        log.info(f"Loading model from checkpoint path {ckpt_path}")
        # map_location="cpu": load_state_dict copies values into the existing
        # parameters, so the checkpoint never needs to sit on an accelerator
        # (and GPU-saved checkpoints stay loadable on CPU-only hosts).
        state_dict = torch.load(
            ckpt_path, map_location="cpu", weights_only=True
        )["state_dict"]
        # strict=False: tolerate missing/unexpected keys and report them below
        load_result = model.load_state_dict(state_dict, strict=False)
        # get compatible keys
        concrete_keys = model.state_dict().keys()
        log.info(
            f"Compatible keys: {set(concrete_keys) - set(load_result.missing_keys)}"
        )
        log.info(f"Unexpected keys: {load_result.unexpected_keys}")

    if aux_path := cfg.get("aux_ckpt_path"):
        log.info(f"Loading auxiliary model from checkpoint path {aux_path}")
        aux_ckpt = torch.load(
            aux_path, map_location="cpu", weights_only=True
        )["state_dict"]
        # Drop every entry whose name contains "fc" and strip "net." from the
        # remaining names so they match the auxiliary sub-module's own keys.
        # NOTE(review): str.replace removes *every* "net." occurrence, not
        # only a leading prefix - confirm that is intended for these keys.
        aux_ckpt = {
            k.replace("net.", ""): v
            for k, v in aux_ckpt.items()
            if "fc" not in k
        }
        aux_net = model.net.attention.attention
        load_result = aux_net.load_state_dict(aux_ckpt, strict=False)
        # check if the model is in eval mode and has no gradients
        log.info(f"Aux. model in eval mode: {not aux_net.training}")
        has_no_grads = not any(p.requires_grad for p in aux_net.parameters())
        log.info(f"Aux. model has no gradients: {has_no_grads}")
        # get compatible keys (sorted so the log output is reproducible)
        loaded_keys = sorted(
            set(aux_net.state_dict().keys()) - set(load_result.missing_keys)
        )
        log.info(_summarize_aux_keys(loaded_keys, load_result.missing_keys))
        log.info(f"Unexpected keys: {load_result.unexpected_keys}")

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )
    # Instantiated only for its side effect; the result is discarded.
    # NOTE(review): presumably configures matmul precision - confirm in the
    # `matmul` config group.
    hydra.utils.instantiate(cfg.matmul)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        trainable = log_hyperparameters(object_dict)
        if not trainable:
            log.warning("Exiting training due to non-trainable model.")
            return {}, object_dict

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule)

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule)

    test_metrics = trainer.callback_metrics

    # merge train and test metrics (test values win on duplicate names)
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict


@hydra.main(
    version_base="1.3", config_path="configs", config_name="train.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with optimized metric value.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # return optimized metric
    return metric_value


if __name__ == "__main__":
    main()
log =
<RankedLogger ptame.train (INFO)>
def
train( cfg: omegaconf.dictconfig.DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Run ``task_func`` with ``cfg`` and perform end-of-run housekeeping.

    Any ``Exception`` raised by the task is logged and re-raised. Whether the
    task succeeds or fails, the output directory is logged and an active
    wandb run (if wandb is installed) is finished with exit code 0 on success
    or 1 otherwise.

    :param cfg: A DictConfig configuration; ``cfg.paths.output_dir`` is read
        for the closing log message.
    :return: The ``(metric_dict, object_dict)`` pair produced by the task.
    """
    succeeded = False
    try:
        # execute the task
        metrics, objects = task_func(cfg=cfg)
        succeeded = True
    except Exception as err:
        # save exception to `.log` file
        log.exception("")
        # Some hyperparameter combinations might be invalid or cause
        # out-of-memory errors, so when using hparam search plugins like
        # Optuna you might want to disable re-raising here to avoid a
        # multirun failure.
        raise err
    finally:
        # runs after success and after failure alike
        log.info(f"Output dir: {cfg.paths.output_dir}")

        # always close wandb run (even on failure, so multirun won't fail);
        # import lazily, and only when wandb is actually installed
        if find_spec("wandb"):
            import wandb

            if wandb.run:
                log.info("Closing wandb!")
                wandb.finish(exit_code=0 if succeeded else 1)

    return metrics, objects
Trains the model. Can additionally evaluate on a test set, using the weights the model holds after training (no checkpoint path is passed to the test call).
This function is wrapped in the optional @task_wrapper decorator, which controls the behavior during failure. Useful for multiruns, saving info about the crash, etc.
Parameters
- cfg: A DictConfig configuration composed by Hydra.
Returns
A tuple with metrics and dict with all instantiated objects.
@hydra.main(version_base='1.3', config_path='configs', config_name='train.yaml')
def
main(cfg: omegaconf.dictconfig.DictConfig) -> Optional[float]:
@hydra.main(
    version_base="1.3", config_path="configs", config_name="train.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run training and report the optimized metric.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: The value returned by ``get_metric_value`` for the metric named
        by ``cfg.optimized_metric`` (``Optional[float]``).
    """
    # extra utilities first
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    # run the training task; the dict of instantiated objects is not needed
    metrics, _objects = train(cfg)

    # look up the metric that hydra-based hyperparameter optimization targets
    target_metric = cfg.get("optimized_metric")
    return get_metric_value(metric_dict=metrics, metric_name=target_metric)
Main entry point for training.
Parameters
- cfg: DictConfig configuration composed by Hydra.
Returns
Optional[float] with optimized metric value.