ptame.ax_sweep

  1import os
  2from typing import Any, Dict, List, Optional, Tuple
  3
  4import hydra
  5import lightning as L
  6from ax.service.ax_client import AxClient
  7from lightning import Callback, LightningDataModule, LightningModule, Trainer
  8from lightning.pytorch.loggers import Logger
  9from omegaconf import DictConfig, OmegaConf
 10import torch
 11
 12from ptame.utils import (  # noqa: E402
 13    RankedLogger,
 14    ax_wrapper,
 15    extras,
 16    instantiate_callbacks,
 17    instantiate_loggers,
 18    log_hyperparameters,
 19)
 20
 # Module-level logger used by every function in this module.
log = RankedLogger(__name__, rank_zero_only=True)


def _load_aux_state_dict(model: LightningModule, aux_path: str) -> None:
    """Load auxiliary weights into ``model.net.attention.attention``.

    The checkpoint's ``"state_dict"`` entry is loaded non-strictly after
    stripping a leading ``"net."`` from each key; the eval/grad state of the
    target module and the matched, missing and unexpected keys are logged.

    :param model: Model whose ``net.attention.attention`` submodule is loaded.
    :param aux_path: Path to a checkpoint containing a ``"state_dict"`` entry.
    """
    target = model.net.attention.attention
    log.info(f"Loading auxiliary model from checkpoint path {aux_path}")
    # map_location="cpu" lets checkpoints saved on a GPU load on any host;
    # load_state_dict then copies the tensors onto the module's own device.
    state_dict = torch.load(aux_path, map_location="cpu", weights_only=True)[
        "state_dict"
    ]
    prefix = "net."
    # Strip only the leading prefix: str.replace would also mangle nested
    # names that merely contain "net." (e.g. "encoder.subnet.weight").
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    missing_keys, unexpected_keys = target.load_state_dict(
        state_dict, strict=False
    )

    # check if the model is in eval mode and has no gradients
    log.info(f"Aux. model in eval mode: {not target.training}")
    no_grads = not any(p.requires_grad for p in target.parameters())
    log.info(f"Aux. model has no gradients: {no_grads}")

    # Keys of the target module that were actually loaded, in module order.
    missing = set(missing_keys)
    compatible_keys = [k for k in target.state_dict().keys() if k not in missing]
    if len(missing_keys) > 10:
        # Too many missing keys to list: show (at most 10) loaded keys instead.
        shown = compatible_keys[:10]
        if len(compatible_keys) > 10:
            shown = shown + [
                f"... {len(compatible_keys) - 10} more items remaining"
            ]
        log.info(f"Compatible keys: {shown}")
    else:
        log.info(f"Incompatible keys: {missing_keys}")
    log.info(f"Unexpected keys: {unexpected_keys}")


@ax_wrapper
def evaluation_function(
    cfg: DictConfig, **kwargs: Any
) -> Tuple[Dict[str, Any], bool]:
    """Train the model for one Ax trial and, if training finished, test it.

    This function is wrapped by the ``@ax_wrapper`` decorator, which handles
    failures of a trial (marking it failed / saving the experiment, closing
    wandb).

    :param cfg: A DictConfig configuration composed by Hydra.
    :param kwargs: Not used by this function.
    :return: A tuple ``(metric_dict, finished)``. ``finished`` is
        ``trainer.is_last_batch`` after fitting. When it is true,
        ``metric_dict`` maps every objective name in
        ``cfg.hparams_search.objectives.mapping`` to ``(value, None)``, with
        ``value`` read from ``trainer.callback_metrics`` after
        ``trainer.test``; otherwise it is empty.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    # (0 is a valid seed, so compare against None instead of truthiness)
    if cfg.get("seed") is not None:
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(
        cfg.data, test=False
    )

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    if aux_path := cfg.get("aux_ckpt_path"):
        _load_aux_state_dict(model, aux_path)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )
    # Instantiated only for its side effect (return value is discarded);
    # presumably configures matmul precision — confirm against the config.
    hydra.utils.instantiate(cfg.matmul)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting training!")
    trainer.fit(model=model, datamodule=datamodule)
    # Read once so the returned flag always agrees with whether metrics exist.
    finished = trainer.is_last_batch
    metric_dict: Dict[str, Any] = {}
    if finished:
        # No ckpt_path is given, so the in-memory (end-of-training) weights
        # are evaluated.
        trainer.test(model=model, datamodule=datamodule)
        metrics = trainer.callback_metrics
        # keep only the metrics mapped to Ax objective names
        metric_dict = {
            name: (metrics[key].item(), None)
            for name, key in cfg.hparams_search.objectives.mapping.items()
        }
    return metric_dict, finished
125
126
def make_params(params: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert a ``name -> spec`` mapping into Ax's list-of-parameters form.

    Each returned dict is a shallow copy of the corresponding spec with its
    ``"name"`` entry set to the mapping key; ``params`` is left unmodified.

    :param params: Mapping from parameter name to its Ax parameter spec.
    :return: One dict per parameter, in the mapping's iteration order.
    """
    # Build new dicts instead of assigning into ``spec`` so the caller's
    # dictionaries are not mutated as a side effect.
    return [{**spec, "name": name} for name, spec in params.items()]
134
135
def merge_config(cfg, update):
    """Write sampled hyperparameters into ``cfg`` in place.

    Each key of ``update`` is a dot-separated path into ``cfg`` (for example
    ``"model.optimizer.lr"``); the node at that path is set to the value.
    If ``cfg.hparams_search.search_space.composition_constraint`` is set, its
    ``target`` path is first assigned ``total`` minus the sum of the values
    at its ``independent`` paths. ``update`` itself is not modified.

    :param cfg: Config to modify in place (e.g. a Hydra ``DictConfig``).
    :param update: Mapping from dotted config path to new value.
    """
    # Work on a copy so the derived target value doesn't leak into the
    # caller's dictionary.
    values = dict(update)
    constraint = cfg.hparams_search.search_space.get("composition_constraint", None)
    if constraint:
        values[constraint.target] = constraint.total - sum(
            values[name] for name in constraint.independent
        )
    for dotted_key, value in values.items():
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for part in parents:
            node = node[part]
        node[leaf] = value
151
152
def _to_plain(value: Any) -> Any:
    """Convert OmegaConf containers to plain dicts/lists; pass others through."""
    if OmegaConf.is_config(value):
        return OmegaConf.to_container(value, resolve=True)
    return value


def _has_status_quo_trial(ax_client: AxClient) -> bool:
    """Return True if the experiment already has a trial on the status-quo arm."""
    trials = ax_client.get_trials_data_frame()
    # An experiment with no trials yields an empty DataFrame without an
    # "arm_name" column, so check for the column before indexing it.
    return "arm_name" in trials.columns and bool(
        (trials["arm_name"] == "status_quo").any()
    )


def _init_ax_client(
    cfg: DictConfig,
    parameters: List[Dict[str, Any]],
    objectives: Any,
    status_quo: Optional[Dict[str, Any]],
    parameter_constraints: List[str],
) -> Tuple[AxClient, Optional[Tuple[Dict[str, Any], int]]]:
    """Resume the Ax client from ``cfg.resume_ax`` or create a new experiment.

    When a status quo is configured and no status-quo trial exists yet, one
    is attached under the arm name ``"status_quo"``.

    :return: ``(ax_client, pending)``, where ``pending`` is the
        ``(params, trial_index)`` of a freshly attached status-quo trial that
        still has to be run, or ``None``.
    """
    pending = None
    if cfg.get("resume_ax"):
        ax_client = AxClient.load_from_json_file(cfg.resume_ax)
        if status_quo and not _has_status_quo_trial(ax_client):
            if ax_client.experiment.status_quo is None:
                ax_client.set_status_quo(status_quo)
            pending = ax_client.attach_trial(status_quo, arm_name="status_quo")
        # let the current config override the saved objectives / search space
        ax_client.set_optimization_config(objectives)
        ax_client.set_search_space(
            parameters, parameter_constraints=parameter_constraints
        )
    else:
        ax_client = AxClient(random_seed=cfg.get("seed"))
        ax_client.create_experiment(
            name=cfg.hparams_search.get("name"),
            parameters=parameters,
            parameter_constraints=parameter_constraints,
            status_quo=status_quo,
            objectives=objectives,
            choose_generation_strategy_kwargs={"max_initialization_trials": 5},
            immutable_search_space_and_opt_config=False,
        )
        if status_quo:
            pending = ax_client.attach_trial(status_quo, arm_name="status_quo")
    return ax_client, pending


@hydra.main(
    version_base="1.3", config_path="configs", config_name="ax_sweep.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for sweeping with Ax.

    Creates (or resumes from ``cfg.resume_ax``) an Ax experiment, then runs
    ``cfg.hparams_search.max_trials`` trials. Each trial trains in
    ``<output_dir>/trial_<index>``. Finished trials report their metrics;
    unfinished ones are stopped early and given
    ``cfg.hparams_search.objectives.worst_result``. The client is saved to
    ``<output_dir>/ax_client.json`` after every trial.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: None; results live in the Ax client and its JSON dump.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)
    worst_result = {
        key: tuple(value)
        for key, value in cfg.hparams_search.objectives.worst_result.items()
    }

    search_space = cfg.hparams_search.search_space
    objectives = hydra.utils.instantiate(cfg.hparams_search.objectives.config)
    parameters = make_params(
        OmegaConf.to_container(hydra.utils.instantiate(search_space.params))
    )
    # Ax's API is typed for plain dicts/lists rather than OmegaConf containers
    status_quo = _to_plain(search_space.get("status_quo", None))
    parameter_constraints = _to_plain(
        search_space.get("parameter_constraints", [])
    )
    ax_client, pending = _init_ax_client(
        cfg, parameters, objectives, status_quo, parameter_constraints
    )

    output_dir = cfg.paths.output_dir
    for _ in range(cfg.hparams_search.get("max_trials")):
        if pending is not None:
            # run the attached status-quo trial before asking Ax for new ones
            params, trial_index = pending
            pending = None
        else:
            params, trial_index = ax_client.get_next_trial()
        cfg.paths.output_dir = os.path.join(output_dir, f"trial_{trial_index}")
        os.makedirs(cfg.paths.output_dir, exist_ok=True)
        merge_config(cfg, params)
        if cfg.get("logger") and cfg.logger.get("wandb"):
            cfg.logger.wandb.name = f"trial_{trial_index}"
        # train the model; the ax_wrapper decorator handles trial failures
        metric_dict, finished = evaluation_function(
            cfg, trial=(trial_index, ax_client), output_dir=output_dir
        )
        if finished:
            ax_client.complete_trial(trial_index, metric_dict)
        else:
            ax_client.stop_trial_early(trial_index)
            ax_client.update_trial_data(trial_index, worst_result)
        # save client after every trial so a sweep can be resumed
        ax_client.save_to_json_file(os.path.join(output_dir, "ax_client.json"))


if __name__ == "__main__":
    main()
log = <RankedLogger ptame.ax_sweep (INFO)>
def evaluation_function( cfg: omegaconf.dictconfig.DictConfig, trial: Tuple[int, ax.service.ax_client.AxClient], output_dir='./') -> Tuple[Dict[str, Any], bool]:
def wrap(
    cfg: DictConfig, trial: Tuple[int, AxClient], output_dir="./"
) -> Tuple[Dict[str, Any], bool]:
    """Run ``task_func`` for one Ax trial and translate failures for Ax.

    - ``KeyboardInterrupt`` / ``RuntimeError``: mark the trial as failed,
      try up to 5 times to save the client to
      ``<output_dir>/killed_experiment.json``, sleep 5 s, then re-raise the
      original exception.
    - Any other ``Exception``: log it and return ``({}, False)``.
    - Always: log ``cfg.paths.output_dir`` and close an active wandb run.

    :param cfg: Hydra config forwarded to ``task_func``.
    :param trial: ``(trial_index, ax_client)`` of the running trial.
    :param output_dir: Directory where ``killed_experiment.json`` is written.
    :return: ``(metric_dict, status)`` from ``task_func``, or ``({}, False)``
        when it raised a non-fatal exception.
    """
    metric_dict = {}
    status = False
    # execute the task
    try:
        metric_dict, status = task_func(cfg=cfg)

    except (KeyboardInterrupt, RuntimeError) as ex:
        log.exception(
            f"Prematurely killed at time: {time.strftime('%Y-%m-%d %H:%M:%S')}"
        )
        trial_index, ax_client = trial
        # save experiment
        ax_client.log_trial_failure(trial_index)
        for attempt in range(5):
            try:
                ax_client.save_to_json_file(
                    f"{output_dir}/killed_experiment.json"
                )
                break
            # Distinct name on purpose: reusing ``ex`` here would rebind it and
            # Python deletes it when this handler exits, so the ``raise ex``
            # below would fail with UnboundLocalError instead of re-raising.
            except Exception as save_err:
                if attempt == 4:
                    log.exception(
                        f"Could not save experiment to json file! <{type(save_err).__name__}>"
                    )

        # NOTE(review): original rationale — saving is assumed not to block,
        # so wait a bit before exiting; confirm whether this is still needed.
        time.sleep(5)
        raise ex
    except Exception as ex:
        # Deliberately not re-raised: some hyperparameter combinations may be
        # invalid, and swallowing lets the sweep continue with status False.
        log.exception(f"Exception occurred! type: <{type(ex).__name__}>")
    # things to always do after either success or exception
    finally:
        # display output dir path in terminal
        log.info(f"Output dir: {cfg.paths.output_dir}")

        # always close wandb run (even if exception occurs so multirun won't fail)
        if find_spec("wandb"):  # check if wandb is installed
            import wandb

            if wandb.run:
                log.info("Closing wandb!")
                wandb.finish(exit_code=0 if status else 1)

    return metric_dict, status

Trains the model and, if training reached its last batch, evaluates it with `trainer.test` using the model's current (end-of-training) weights.

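This function is wrapped in the @ax_wrapper decorator, which handles trial failures. On a KeyboardInterrupt or RuntimeError, it marks the Ax trial as failed, saves the experiment and re-raises. Other exceptions are logged and reported as an unfinished trial. In every case, any active wandb run is closed.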

Parameters
  • cfg: A DictConfig configuration composed by Hydra.
Returns

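A tuple (metric_dict, finished). `metric_dict` holds the objective metrics as (value, None) pairs and is empty if training did not finish. `finished` is `trainer.is_last_batch` after fitting.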

def make_params(params: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
def make_params(params: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert a ``name -> spec`` mapping into Ax's list-of-parameters form.

    Each returned dict is a shallow copy of the corresponding spec with its
    ``"name"`` entry set to the mapping key; ``params`` is left unmodified.

    :param params: Mapping from parameter name to its Ax parameter spec.
    :return: One dict per parameter, in the mapping's iteration order.
    """
    # Build new dicts instead of assigning into ``spec`` so the caller's
    # dictionaries are not mutated as a side effect.
    return [{**spec, "name": name} for name, spec in params.items()]

Creates a list of dictionaries with the parameters to optimize.

def merge_config(cfg, update):
def merge_config(cfg, update):
    """Write sampled hyperparameters into ``cfg`` in place.

    Each key of ``update`` is a dot-separated path into ``cfg`` (for example
    ``"model.optimizer.lr"``); the node at that path is set to the value.
    If ``cfg.hparams_search.search_space.composition_constraint`` is set, its
    ``target`` path is first assigned ``total`` minus the sum of the values
    at its ``independent`` paths. ``update`` itself is not modified.

    :param cfg: Config to modify in place (e.g. a Hydra ``DictConfig``).
    :param update: Mapping from dotted config path to new value.
    """
    # Work on a copy so the derived target value doesn't leak into the
    # caller's dictionary.
    values = dict(update)
    constraint = cfg.hparams_search.search_space.get("composition_constraint", None)
    if constraint:
        values[constraint.target] = constraint.total - sum(
            values[name] for name in constraint.independent
        )
    for dotted_key, value in values.items():
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for part in parents:
            node = node[part]
        node[leaf] = value

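Merges the update into the config in place: each dotted-path key in the update is written to the matching config node. If a composition constraint is configured, its target is first set to the total minus the sum of the independent values.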

@hydra.main(version_base='1.3', config_path='configs', config_name='ax_sweep.yaml')
def main(cfg: omegaconf.dictconfig.DictConfig) -> Optional[float]:
def _to_plain(value: Any) -> Any:
    """Convert OmegaConf containers to plain dicts/lists; pass others through."""
    if OmegaConf.is_config(value):
        return OmegaConf.to_container(value, resolve=True)
    return value


def _has_status_quo_trial(ax_client: AxClient) -> bool:
    """Return True if the experiment already has a trial on the status-quo arm."""
    trials = ax_client.get_trials_data_frame()
    # An experiment with no trials yields an empty DataFrame without an
    # "arm_name" column, so check for the column before indexing it.
    return "arm_name" in trials.columns and bool(
        (trials["arm_name"] == "status_quo").any()
    )


def _init_ax_client(
    cfg: DictConfig,
    parameters: List[Dict[str, Any]],
    objectives: Any,
    status_quo: Optional[Dict[str, Any]],
    parameter_constraints: List[str],
) -> Tuple[AxClient, Optional[Tuple[Dict[str, Any], int]]]:
    """Resume the Ax client from ``cfg.resume_ax`` or create a new experiment.

    When a status quo is configured and no status-quo trial exists yet, one
    is attached under the arm name ``"status_quo"``.

    :return: ``(ax_client, pending)``, where ``pending`` is the
        ``(params, trial_index)`` of a freshly attached status-quo trial that
        still has to be run, or ``None``.
    """
    pending = None
    if cfg.get("resume_ax"):
        ax_client = AxClient.load_from_json_file(cfg.resume_ax)
        if status_quo and not _has_status_quo_trial(ax_client):
            if ax_client.experiment.status_quo is None:
                ax_client.set_status_quo(status_quo)
            pending = ax_client.attach_trial(status_quo, arm_name="status_quo")
        # let the current config override the saved objectives / search space
        ax_client.set_optimization_config(objectives)
        ax_client.set_search_space(
            parameters, parameter_constraints=parameter_constraints
        )
    else:
        ax_client = AxClient(random_seed=cfg.get("seed"))
        ax_client.create_experiment(
            name=cfg.hparams_search.get("name"),
            parameters=parameters,
            parameter_constraints=parameter_constraints,
            status_quo=status_quo,
            objectives=objectives,
            choose_generation_strategy_kwargs={"max_initialization_trials": 5},
            immutable_search_space_and_opt_config=False,
        )
        if status_quo:
            pending = ax_client.attach_trial(status_quo, arm_name="status_quo")
    return ax_client, pending


@hydra.main(
    version_base="1.3", config_path="configs", config_name="ax_sweep.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for sweeping with Ax.

    Creates (or resumes from ``cfg.resume_ax``) an Ax experiment, then runs
    ``cfg.hparams_search.max_trials`` trials. Each trial trains in
    ``<output_dir>/trial_<index>``. Finished trials report their metrics;
    unfinished ones are stopped early and given
    ``cfg.hparams_search.objectives.worst_result``. The client is saved to
    ``<output_dir>/ax_client.json`` after every trial.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: None; results live in the Ax client and its JSON dump.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)
    worst_result = {
        key: tuple(value)
        for key, value in cfg.hparams_search.objectives.worst_result.items()
    }

    search_space = cfg.hparams_search.search_space
    objectives = hydra.utils.instantiate(cfg.hparams_search.objectives.config)
    parameters = make_params(
        OmegaConf.to_container(hydra.utils.instantiate(search_space.params))
    )
    # Ax's API is typed for plain dicts/lists rather than OmegaConf containers
    status_quo = _to_plain(search_space.get("status_quo", None))
    parameter_constraints = _to_plain(
        search_space.get("parameter_constraints", [])
    )
    ax_client, pending = _init_ax_client(
        cfg, parameters, objectives, status_quo, parameter_constraints
    )

    output_dir = cfg.paths.output_dir
    for _ in range(cfg.hparams_search.get("max_trials")):
        if pending is not None:
            # run the attached status-quo trial before asking Ax for new ones
            params, trial_index = pending
            pending = None
        else:
            params, trial_index = ax_client.get_next_trial()
        cfg.paths.output_dir = os.path.join(output_dir, f"trial_{trial_index}")
        os.makedirs(cfg.paths.output_dir, exist_ok=True)
        merge_config(cfg, params)
        if cfg.get("logger") and cfg.logger.get("wandb"):
            cfg.logger.wandb.name = f"trial_{trial_index}"
        # train the model; the ax_wrapper decorator handles trial failures
        metric_dict, finished = evaluation_function(
            cfg, trial=(trial_index, ax_client), output_dir=output_dir
        )
        if finished:
            ax_client.complete_trial(trial_index, metric_dict)
        else:
            ax_client.stop_trial_early(trial_index)
            ax_client.update_trial_data(trial_index, worst_result)
        # save client after every trial so a sweep can be resumed
        ax_client.save_to_json_file(os.path.join(output_dir, "ax_client.json"))

Main entry point for sweeping with ax.

Parameters
  • cfg: DictConfig configuration composed by Hydra.
Returns

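None. Despite the Optional[float] annotation, no value is returned; trial results are recorded in the Ax client, which is saved to ax_client.json in the output directory after every trial.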