ptame.ax_sweep
import os
from typing import Any, Dict, List, Optional, Tuple

import hydra
import lightning as L
import torch
from ax.service.ax_client import AxClient
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, OmegaConf

from ptame.utils import (  # noqa: E402
    RankedLogger,
    ax_wrapper,
    extras,
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
)

log = RankedLogger(__name__, rank_zero_only=True)

# Maximum number of state-dict keys echoed to the log when reporting which
# auxiliary-checkpoint weights were (not) loaded.
_MAX_LOGGED_KEYS = 10


def _strip_prefix(key: str, prefix: str) -> str:
    """Return ``key`` without a leading ``prefix`` (unchanged if it is absent)."""
    return key[len(prefix):] if key.startswith(prefix) else key


def _load_aux_checkpoint(model: LightningModule, aux_path: str) -> None:
    """Load auxiliary weights into ``model.net.attention.attention``.

    The checkpoint's ``state_dict`` keys have a leading ``"net."`` removed and
    are then loaded non-strictly; the outcome is reported through ``log``.

    :param model: Instantiated model exposing ``net.attention.attention``.
    :param aux_path: Path to a checkpoint file containing a ``state_dict`` entry.
    """
    log.info(f"Loading auxiliary model from checkpoint path {aux_path}")
    aux_module = model.net.attention.attention
    # load_state_dict copies into the module's own tensors, so the checkpoint
    # does not need to be materialised on the device it was saved from.
    state_dict = torch.load(aux_path, map_location="cpu", weights_only=True)[
        "state_dict"
    ]
    # Strip only the leading prefix: str.replace would also mangle keys that
    # merely contain "net." somewhere (e.g. "resnet.conv1.weight").
    state_dict = {_strip_prefix(k, "net."): v for k, v in state_dict.items()}
    missing_keys, unexpected_keys = aux_module.load_state_dict(
        state_dict, strict=False
    )
    # check if the model is in eval mode and has no gradients
    log.info(f"Aux. model in eval mode: {not aux_module.training}")
    log.info(
        "Aux. model has no gradients: "
        f"{not any(p.requires_grad for p in aux_module.parameters())}"
    )
    if len(missing_keys) > _MAX_LOGGED_KEYS:
        # Many keys are missing: listing the keys that did load is more useful.
        compatible_keys = sorted(set(aux_module.state_dict()) - set(missing_keys))
        shown = compatible_keys[:_MAX_LOGGED_KEYS]
        if len(compatible_keys) > _MAX_LOGGED_KEYS:
            shown.append(
                f"... {len(compatible_keys) - _MAX_LOGGED_KEYS} more items remaining"
            )
        log.info(f"Compatible keys: {shown}")
    else:
        log.info(f"Incompatible keys: {missing_keys}")
    log.info(f"Unexpected keys: {unexpected_keys}")


@ax_wrapper
def evaluation_function(
    cfg: DictConfig, **kwargs: Any
) -> Tuple[Dict[str, Any], bool]:
    """Train the model for one Ax trial and collect its objective metrics.

    If training reaches its last batch (``trainer.is_last_batch``), the model
    is evaluated with ``trainer.test`` using its current in-memory weights, and
    the metrics named in ``cfg.hparams_search.objectives.mapping`` are read from
    ``trainer.callback_metrics``.

    This function is wrapped by ``@ax_wrapper`` (``ptame.utils``), which is
    called with ``cfg``, ``trial`` and ``output_dir`` and handles failures.

    :param cfg: A DictConfig configuration composed by Hydra.
    :param kwargs: Not used by the body.
    :return: ``(metric_dict, completed)``. ``metric_dict`` maps each objective
        name to a ``(value, None)`` pair in the ``(mean, sem)`` form passed to
        ``AxClient.complete_trial``; it is empty if training did not complete.
        ``completed`` is ``trainer.is_last_batch``.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(
        cfg.data, test=False
    )

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    if aux_path := cfg.get("aux_ckpt_path"):
        _load_aux_checkpoint(model, aux_path)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )
    # Instantiated only for its side effect; the returned object is discarded.
    hydra.utils.instantiate(cfg.matmul)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting training!")
    trainer.fit(model=model, datamodule=datamodule)

    completed = trainer.is_last_batch
    metric_dict: Dict[str, Any] = {}
    if completed:
        # The datamodule was built with test=False; presumably its test split
        # is then the validation data -- confirm in the datamodule.
        trainer.test(model=model, datamodule=datamodule)
        metrics = trainer.callback_metrics
        # keep only the metrics listed in the objectives mapping
        metric_dict = {
            name: (metrics[key].item(), None)
            for name, key in cfg.hparams_search.objectives.mapping.items()
        }
    return metric_dict, completed


def make_params(params: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert a ``{name: spec}`` mapping into a list of Ax parameter dicts.

    Each output dict is a shallow copy of its spec with ``"name"`` set to the
    mapping key. The input dicts are left unmodified.

    :param params: Mapping of parameter name to its Ax parameter spec.
    :return: One dict per parameter, in the mapping's iteration order.
    """
    return [{**spec, "name": name} for name, spec in params.items()]


def merge_config(cfg, update):
    """Write trial parameters into ``cfg`` in place.

    Keys of ``update`` are dotted paths into ``cfg`` (e.g. ``"model.lr"``).
    If ``cfg.hparams_search.search_space.composition_constraint`` is set, its
    ``target`` path is assigned ``total - sum(update[k] for k in independent)``
    before writing.

    :param cfg: Configuration to modify in place.
    :param update: Mapping of dotted config paths to values; not modified.
    """
    # Work on a copy so the derived target does not leak into the caller's dict.
    values = dict(update)
    reparam = cfg.hparams_search.search_space.get("composition_constraint", None)
    if reparam:
        values[reparam.target] = reparam.total - sum(
            values[i] for i in reparam.independent
        )
    for dotted_key, value in values.items():
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for part in parents:
            node = node[part]
        node[leaf] = value


def _to_plain(value: Any) -> Any:
    """Convert OmegaConf containers to plain Python objects; pass others through."""
    if OmegaConf.is_config(value):
        return OmegaConf.to_container(value, resolve=True)
    return value


def _has_status_quo_trial(ax_client: AxClient) -> bool:
    """Return whether ``ax_client`` has a trial whose arm is named ``"status_quo"``."""
    trials = ax_client.get_trials_data_frame()
    # Guard against a frame without an "arm_name" column (e.g. no trials yet).
    return "arm_name" in trials.columns and bool(
        (trials["arm_name"] == "status_quo").any()
    )


@hydra.main(
    version_base="1.3", config_path="configs", config_name="ax_sweep.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Run, or resume, an Ax hyperparameter sweep.

    The Ax client is built from ``cfg.hparams_search``, or reloaded from
    ``cfg.resume_ax``. If ``status_quo`` is configured, a status-quo trial is
    attached and run first. The function then runs up to
    ``cfg.hparams_search.max_trials`` trials through ``evaluation_function``.
    Trials that do not complete are stopped early and recorded with
    ``worst_result``. The client is saved to ``<output_dir>/ax_client.json``
    after every trial and once more at the end.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Always ``None``; results are persisted in the saved Ax client.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)
    search_cfg = cfg.hparams_search
    worst_result = {
        key: tuple(value)
        for key, value in search_cfg.objectives.worst_result.items()
    }

    objectives = hydra.utils.instantiate(search_cfg.objectives.config)
    parameters = make_params(
        OmegaConf.to_container(
            hydra.utils.instantiate(search_cfg.search_space.params)
        )
    )
    # Give Ax plain dicts/lists instead of OmegaConf containers; Ax may keep
    # and JSON-serialise them (e.g. as the status-quo arm's parameters).
    status_quo = _to_plain(search_cfg.search_space.get("status_quo", None))
    parameter_constraints = _to_plain(
        search_cfg.search_space.get("parameter_constraints", [])
    )
    run_status_quo_first = False
    if cfg.get("resume_ax"):
        ax_client = AxClient.load_from_json_file(cfg.resume_ax)
        if status_quo and not _has_status_quo_trial(ax_client):
            if ax_client.experiment.status_quo is None:
                ax_client.set_status_quo(status_quo)
            params, trial_index = ax_client.attach_trial(
                status_quo, arm_name="status_quo"
            )
            run_status_quo_first = True
        # Replace the saved objectives/search space with the current config
        # (presumably the experiment was created below with
        # immutable_search_space_and_opt_config=False).
        ax_client.set_optimization_config(objectives)
        ax_client.set_search_space(
            parameters, parameter_constraints=parameter_constraints
        )
    else:
        ax_client = AxClient(random_seed=cfg.get("seed"))
        ax_client.create_experiment(
            name=search_cfg.get("name"),
            parameters=parameters,
            parameter_constraints=parameter_constraints,
            status_quo=status_quo,
            objectives=objectives,
            choose_generation_strategy_kwargs={"max_initialization_trials": 5},
            immutable_search_space_and_opt_config=False,
        )
        if status_quo:
            run_status_quo_first = True
            params, trial_index = ax_client.attach_trial(
                status_quo, arm_name="status_quo"
            )

    output_dir = cfg.paths.output_dir
    client_path = os.path.join(output_dir, "ax_client.json")
    for _ in range(search_cfg.get("max_trials")):
        if run_status_quo_first:
            # params/trial_index were set by the attach_trial call above.
            run_status_quo_first = False
        else:
            params, trial_index = ax_client.get_next_trial()
        cfg.paths.output_dir = os.path.join(output_dir, f"trial_{trial_index}")
        os.makedirs(cfg.paths.output_dir, exist_ok=True)
        merge_config(cfg, params)
        if cfg.get("logger") and cfg.logger.get("wandb"):
            cfg.logger.wandb.name = f"trial_{trial_index}"
        # train the model
        metric_dict, completed = evaluation_function(
            cfg, trial=(trial_index, ax_client), output_dir=output_dir
        )
        if completed:
            ax_client.complete_trial(trial_index, metric_dict)
        else:
            ax_client.stop_trial_early(trial_index)
            ax_client.update_trial_data(trial_index, worst_result)
        # Checkpoint after every trial so an interrupted sweep can be resumed
        # (via ``resume_ax``) without losing finished trials.
        ax_client.save_to_json_file(client_path)
    # Final save; also covers the case where no trial ran.
    ax_client.save_to_json_file(client_path)


if __name__ == "__main__":
    main()
log =
<RankedLogger ptame.ax_sweep (INFO)>
def
evaluation_function( cfg: omegaconf.dictconfig.DictConfig, trial: Tuple[int, ax.service.ax_client.AxClient], output_dir='./') -> Tuple[Dict[str, Any], bool]:
def wrap(
    cfg: DictConfig, trial: Tuple[int, AxClient], output_dir="./"
) -> Tuple[Dict[str, Any], bool]:
    """Run the wrapped ``task_func`` for one Ax trial and handle failures.

    - ``KeyboardInterrupt`` / ``RuntimeError``: the trial is logged as failed
      on the Ax client, and the experiment is saved to
      ``<output_dir>/killed_experiment.json`` (up to 5 attempts). After a
      5-second pause the original exception is re-raised.
    - Any other ``Exception``: it is logged and ``({}, False)`` is returned.
    - Always: the output dir is logged and an active wandb run is finished,
      with exit code 0 if ``status`` is truthy and 1 otherwise.

    :param cfg: Hydra config forwarded to ``task_func``.
    :param trial: ``(trial_index, ax_client)`` of the running trial.
    :param output_dir: Directory for the ``killed_experiment.json`` snapshot.
    :return: ``(metric_dict, status)`` as returned by ``task_func``, or
        ``({}, False)`` if it raised an ordinary ``Exception``.
    """
    metric_dict = {}
    status = False
    # execute the task
    try:
        metric_dict, status = task_func(cfg=cfg)

    except (KeyboardInterrupt, RuntimeError):
        log.exception(
            f"Prematurely killed at time: {time.strftime('%Y-%m-%d %H:%M:%S')}"
        )
        trial_index, ax_client = trial
        # save experiment
        ax_client.log_trial_failure(trial_index)
        max_attempts = 5
        for attempt in range(max_attempts):
            try:
                ax_client.save_to_json_file(f"{output_dir}/killed_experiment.json")
                break
            # Distinct name on purpose: an ``except ... as`` target is deleted
            # when its clause ends, so reusing the outer handler's name would
            # leave nothing to re-raise below.
            except Exception as save_error:
                if attempt == max_attempts - 1:
                    log.exception(
                        "Could not save experiment to json file! "
                        f"<{type(save_error).__name__}>"
                    )

        # NOTE(review): the original rationale is that saving may still be in
        # progress when this point is reached; confirm the pause is needed.
        time.sleep(5)
        # Bare ``raise`` re-raises the KeyboardInterrupt/RuntimeError being
        # handled; the nested save handler above does not replace it.
        raise
    # things to do if exception occurs
    except Exception as ex:
        log.exception(f"Exception occurred! type: <{type(ex).__name__}>")
        # The exception is deliberately not re-raised: some hyperparameter
        # combinations may be invalid or run out of memory, and the caller
        # records such trials as failed instead of aborting the sweep.
    # things to always do after either success or exception
    finally:
        # display output dir path in terminal
        log.info(f"Output dir: {cfg.paths.output_dir}")

        # always close wandb run (even if exception occurs so multirun won't fail)
        if find_spec("wandb"):  # check if wandb is installed
            import wandb

            if wandb.run:
                log.info("Closing wandb!")
                exit_code = 0 if status else 1
                wandb.finish(exit_code=exit_code)

    return metric_dict, status
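Trains the model. If training reached its last batch, it then evaluates the model with `trainer.test`, using the final in-memory weights (not a best checkpoint), and collects the objective metrics for Ax.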
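This function is wrapped by the `@ax_wrapper` decorator, which controls the behavior on failure. On an interrupt or `RuntimeError`, it marks the trial as failed, saves the experiment, and re-raises; any other exception is logged and reported as an incomplete trial. It always closes an active wandb run.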
Parameters
- cfg: A DictConfig configuration composed by Hydra.
Returns
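A tuple `(metric_dict, completed)`. The first element maps each objective name to a `(value, None)` pair and is empty if training did not finish. The second element states whether training reached its last batch.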
def
make_params(params: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
def make_params(params: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert a ``{name: spec}`` mapping into a list of Ax parameter dicts.

    Each output dict is a shallow copy of its spec with ``"name"`` set to the
    mapping key. The input dicts are left unmodified; previously they were
    mutated in place.

    :param params: Mapping of parameter name to its Ax parameter spec.
    :return: One dict per parameter, in the mapping's iteration order.
    """
    return [{**spec, "name": name} for name, spec in params.items()]
Creates a list of dictionaries with the parameters to optimize.
def
merge_config(cfg, update):
def merge_config(cfg, update):
    """Write trial parameters into ``cfg`` in place.

    Keys of ``update`` are dotted paths into ``cfg`` (e.g. ``"model.lr"``).
    If ``cfg.hparams_search.search_space.composition_constraint`` is set, its
    ``target`` path is assigned ``total - sum(update[k] for k in independent)``
    before writing.

    :param cfg: Configuration to modify in place.
    :param update: Mapping of dotted config paths to values; not modified.
    """
    # Work on a copy so the derived target does not leak into the caller's dict.
    values = dict(update)
    reparam = cfg.hparams_search.search_space.get("composition_constraint", None)
    if reparam:
        values[reparam.target] = reparam.total - sum(
            values[i] for i in reparam.independent
        )
    for dotted_key, value in values.items():
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for part in parents:
            node = node[part]
        node[leaf] = value
Merges the config with the update.
@hydra.main(version_base='1.3', config_path='configs', config_name='ax_sweep.yaml')
def
main(cfg: omegaconf.dictconfig.DictConfig) -> Optional[float]:
def _to_plain(value: Any) -> Any:
    """Convert OmegaConf containers to plain Python objects; pass others through."""
    if OmegaConf.is_config(value):
        return OmegaConf.to_container(value, resolve=True)
    return value


def _has_status_quo_trial(ax_client: AxClient) -> bool:
    """Return whether ``ax_client`` has a trial whose arm is named ``"status_quo"``."""
    trials = ax_client.get_trials_data_frame()
    # Guard against a frame without an "arm_name" column (e.g. no trials yet).
    return "arm_name" in trials.columns and bool(
        (trials["arm_name"] == "status_quo").any()
    )


@hydra.main(
    version_base="1.3", config_path="configs", config_name="ax_sweep.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Run, or resume, an Ax hyperparameter sweep.

    The Ax client is built from ``cfg.hparams_search``, or reloaded from
    ``cfg.resume_ax``. If ``status_quo`` is configured, a status-quo trial is
    attached and run first. The function then runs up to
    ``cfg.hparams_search.max_trials`` trials through ``evaluation_function``.
    Trials that do not complete are stopped early and recorded with
    ``worst_result``. The client is saved to ``<output_dir>/ax_client.json``
    after every trial and once more at the end.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Always ``None``; results are persisted in the saved Ax client.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)
    search_cfg = cfg.hparams_search
    worst_result = {
        key: tuple(value)
        for key, value in search_cfg.objectives.worst_result.items()
    }

    objectives = hydra.utils.instantiate(search_cfg.objectives.config)
    parameters = make_params(
        OmegaConf.to_container(
            hydra.utils.instantiate(search_cfg.search_space.params)
        )
    )
    # Give Ax plain dicts/lists instead of OmegaConf containers; Ax may keep
    # and JSON-serialise them (e.g. as the status-quo arm's parameters).
    status_quo = _to_plain(search_cfg.search_space.get("status_quo", None))
    parameter_constraints = _to_plain(
        search_cfg.search_space.get("parameter_constraints", [])
    )
    run_status_quo_first = False
    if cfg.get("resume_ax"):
        ax_client = AxClient.load_from_json_file(cfg.resume_ax)
        if status_quo and not _has_status_quo_trial(ax_client):
            if ax_client.experiment.status_quo is None:
                ax_client.set_status_quo(status_quo)
            params, trial_index = ax_client.attach_trial(
                status_quo, arm_name="status_quo"
            )
            run_status_quo_first = True
        # Replace the saved objectives/search space with the current config
        # (presumably the experiment was created below with
        # immutable_search_space_and_opt_config=False).
        ax_client.set_optimization_config(objectives)
        ax_client.set_search_space(
            parameters, parameter_constraints=parameter_constraints
        )
    else:
        ax_client = AxClient(random_seed=cfg.get("seed"))
        ax_client.create_experiment(
            name=search_cfg.get("name"),
            parameters=parameters,
            parameter_constraints=parameter_constraints,
            status_quo=status_quo,
            objectives=objectives,
            choose_generation_strategy_kwargs={"max_initialization_trials": 5},
            immutable_search_space_and_opt_config=False,
        )
        if status_quo:
            run_status_quo_first = True
            params, trial_index = ax_client.attach_trial(
                status_quo, arm_name="status_quo"
            )

    output_dir = cfg.paths.output_dir
    client_path = os.path.join(output_dir, "ax_client.json")
    for _ in range(search_cfg.get("max_trials")):
        if run_status_quo_first:
            # params/trial_index were set by the attach_trial call above.
            run_status_quo_first = False
        else:
            params, trial_index = ax_client.get_next_trial()
        cfg.paths.output_dir = os.path.join(output_dir, f"trial_{trial_index}")
        os.makedirs(cfg.paths.output_dir, exist_ok=True)
        merge_config(cfg, params)
        if cfg.get("logger") and cfg.logger.get("wandb"):
            cfg.logger.wandb.name = f"trial_{trial_index}"
        # train the model
        metric_dict, completed = evaluation_function(
            cfg, trial=(trial_index, ax_client), output_dir=output_dir
        )
        if completed:
            ax_client.complete_trial(trial_index, metric_dict)
        else:
            ax_client.stop_trial_early(trial_index)
            ax_client.update_trial_data(trial_index, worst_result)
        # Checkpoint after every trial so an interrupted sweep can be resumed
        # (via ``resume_ax``) without losing finished trials.
        ax_client.save_to_json_file(client_path)
    # Final save; also covers the case where no trial ran.
    ax_client.save_to_json_file(client_path)
Main entry point for sweeping with ax.
Parameters
- cfg: DictConfig configuration composed by Hydra.
Returns
Always None. No metric value is returned; the Ax client, with all trial results, is saved to `ax_client.json` in the output directory.