ptame.utils
"""Public helpers re-exported by the ``ptame.utils`` package."""

from .instantiators import instantiate_callbacks, instantiate_loggers
from .logging_utils import log_hyperparameters
from .memory_format import MemoryFormat
from .pylogger import RankedLogger
from .rich_utils import enforce_tags, print_config_tree
from .utils import ax_wrapper, extras, get_metric_value, task_wrapper

# Names exported by ``from ptame.utils import *`` (order kept as published).
__all__ = [
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "RankedLogger",
    "enforce_tags",
    "print_config_tree",
    "ax_wrapper",
    "extras",
    "get_metric_value",
    "task_wrapper",
    "MemoryFormat",
]
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Build Lightning callbacks from a Hydra config group.

    Only entries that are themselves ``DictConfig`` nodes carrying a
    ``_target_`` key are passed to ``hydra.utils.instantiate``; every other
    entry is skipped without a message.

    :param callbacks_cfg: A DictConfig object containing callback
        configurations.
    :raises TypeError: If ``callbacks_cfg`` is non-empty but not a DictConfig.
    :return: A list of instantiated callbacks (empty when the config is empty
        or ``None``).
    """
    # An empty or missing config is tolerated: warn and run without callbacks.
    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    built: List[Callback] = []
    for _key, conf in callbacks_cfg.items():
        # guard clause: ignore nodes that are not instantiable configs
        if not (isinstance(conf, DictConfig) and "_target_" in conf):
            continue
        log.info(f"Instantiating callback <{conf._target_}>")
        built.append(hydra.utils.instantiate(conf))

    return built
Instantiates callbacks from config.
Parameters
- callbacks_cfg: A DictConfig object containing callback configurations.
Returns
A list of instantiated callbacks.
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build Lightning loggers from a Hydra config group.

    Each entry that is a ``DictConfig`` node with a ``_target_`` key is
    instantiated through ``hydra.utils.instantiate``; anything else is ignored.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :raises TypeError: If ``logger_cfg`` is non-empty but not a DictConfig.
    :return: A list of instantiated loggers (empty when the config is empty
        or ``None``).
    """
    created: List[Logger] = []

    # An empty or missing config is allowed: warn and return no loggers.
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return created

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _name, node in logger_cfg.items():
        is_instantiable = isinstance(node, DictConfig) and "_target_" in node
        if is_instantiable:
            log.info(f"Instantiating logger <{node._target_}>")
            created.append(hydra.utils.instantiate(node))

    return created
Instantiates loggers from config.
Parameters
- logger_cfg: A DictConfig object containing logger configurations.
Returns
A list of instantiated loggers.
@rank_zero_only
def log_hyperparameters(object_dict: dict[str, Any]) -> bool:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
        - Number of model parameters (total, trainable, non-trainable)
        - Model GFLOPs (forward, backbone forward, forward+backward) unless
          ``cfg.log_flops`` is set to a falsy value
        - ``model.net.attention.resolution`` when ``model.net`` has an
          ``attention`` attribute

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    :return: ``True`` if the model has at least one trainable parameter,
        ``False`` otherwise. If the trainer has no ``loggers`` attribute,
        nothing is logged and ``None`` is returned.
    """
    hparams = {}

    # resolved plain-Python containers, so values can be handed to any logger
    cfg = OmegaConf.to_container(object_dict["cfg"], resolve=True)
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    try:
        trainer.loggers
    except AttributeError:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )

    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    if cfg.get("log_flops", True):
        # FLOPs are measured on a separate model built from the same config so
        # the model being trained is left untouched. It is only built here,
        # when FLOPs are actually requested: instantiating a second model and
        # drawing the dummy input costs time, memory and global RNG state.
        test_model = hydra.utils.instantiate(cfg["model"])
        num_batches = 1
        # NOTE(review): fixed 3x224x224 dummy input; assumes an image model
        # that accepts this resolution — confirm for other datasets.
        x = torch.randn(num_batches, 3, 224, 224)

        # measure the flops caused by the backbone
        # NOTE(review): the factor 2 is kept as-is; its rationale is not
        # visible in this function.
        bb_flops = 2 * measure_flops(
            test_model.net, lambda: test_model.net.get_predictions(x)
        )
        fwd_flops = measure_flops(test_model.net, lambda: test_model.net(x))
        hparams["model/gflops/fwd"] = fwd_flops / 1e9
        hparams["model/gflops/bb_fwd"] = bb_flops / 1e9
        if hparams["model/params/trainable"] != 0:
            # With a loss function, measure_flops also counts the backward
            # pass, so this "bwd" entry holds forward + backward FLOPs.
            # NOTE(review): assumes net(x) returns a mapping that matches the
            # criterion's keyword arguments.
            fwd_and_bwd_flops = measure_flops(
                test_model.net,
                lambda: test_model.net(x),
                lambda out: test_model.criterion(**out)[0],
            )
            hparams["model/gflops/bwd"] = fwd_and_bwd_flops / 1e9
        else:
            hparams["model/gflops/bwd"] = 0

    # save explanation resolution
    if hasattr(model.net, "attention"):
        hparams["model/attention/resolution"] = model.net.attention.resolution
    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
    return hparams["model/params/trainable"] != 0
Controls which config parts are saved by Lightning loggers.
Additionally saves:
- Number of model parameters
Parameters
- object_dict: A dictionary containing the following objects:
  - "cfg": A DictConfig object containing the main config.
  - "model": The Lightning model.
  - "trainer": The Lightning trainer.
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger.

    Messages are prefixed with the rank of the emitting process and can be
    restricted to rank zero or to one specific rank.
    """

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Create a logger that logs on all processes with a rank prefix.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the
            rank zero process. Default is `False`.
        :param extra: (Optional) A dict-like object which provides contextual
            information. See `logging.LoggerAdapter`.
        """
        super().__init__(logger=logging.getLogger(name), extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(
        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
    ) -> None:
        """Forward a log call to the underlying logger with a rank prefix.

        When ``rank`` is given (and ``rank_zero_only`` is off), the message
        is emitted only on that rank.

        NOTE(review): ``rank`` is the third positional parameter. The
        inherited ``info``/``warning``/... helpers forward their positional
        arguments to ``log``, so a %-style format argument passed to them is
        bound to ``rank``. Pass ``rank`` by keyword and pre-format messages
        (e.g. with f-strings).

        :param level: The level to log at. Look at `logging.__init__.py` for
            more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying
            logging function.
        :raises RuntimeError: If ``rank_zero_only.rank`` has not been set.
        """
        if not self.isEnabledFor(level):
            return

        msg, kwargs = self.process(msg, kwargs)
        # `rank_zero_only` here is the module-level decorator object, which
        # carries the process rank as an attribute once it has been set.
        this_rank = getattr(rank_zero_only, "rank", None)
        if this_rank is None:
            raise RuntimeError(
                "The `rank_zero_only.rank` needs to be set before use"
            )
        msg = rank_prefixed_message(msg, this_rank)

        if self.rank_zero_only:
            should_emit = this_rank == 0
        else:
            should_emit = rank is None or this_rank == rank

        if should_emit:
            self.logger.log(level, msg, *args, **kwargs)
A multi-GPU-friendly python command line logger.
def __init__(
    self,
    name: str = __name__,
    rank_zero_only: bool = False,
    extra: Optional[Mapping[str, object]] = None,
) -> None:
    """Create a logger that logs on all processes with a rank prefix.

    :param name: The name of the logger. Default is ``__name__``.
    :param rank_zero_only: Whether to force all logs to only occur on the
        rank zero process. Default is `False`.
    :param extra: (Optional) A dict-like object which provides contextual
        information. See `logging.LoggerAdapter`.
    """
    underlying = logging.getLogger(name)
    super().__init__(logger=underlying, extra=extra)
    self.rank_zero_only = rank_zero_only
Initializes a multi-GPU-friendly python command line logger that logs on all processes with their rank prefixed in the log message.
Parameters
- name: The name of the logger. Default is `__name__`.
- rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
- extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
def log(
    self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
) -> None:
    """Forward a log call to the underlying logger with a rank prefix.

    If ``rank`` is provided (and ``rank_zero_only`` is off), the message is
    only emitted on that rank/process.

    NOTE(review): ``rank`` is positional, so format arguments forwarded by
    ``LoggerAdapter.info``/``warning``/... land in ``rank``; pass ``rank``
    by keyword and pre-format messages.

    :param level: The level to log at. Look at `logging.__init__.py` for more
        information.
    :param msg: The message to log.
    :param rank: The rank to log at.
    :param args: Additional args to pass to the underlying logging function.
    :param kwargs: Any additional keyword args to pass to the underlying
        logging function.
    :raises RuntimeError: If ``rank_zero_only.rank`` has not been set.
    """
    if not self.isEnabledFor(level):
        return

    msg, kwargs = self.process(msg, kwargs)
    proc_rank = getattr(rank_zero_only, "rank", None)
    if proc_rank is None:
        raise RuntimeError(
            "The `rank_zero_only.rank` needs to be set before use"
        )
    msg = rank_prefixed_message(msg, proc_rank)

    if self.rank_zero_only:
        allowed = proc_rank == 0
    else:
        allowed = rank is None or proc_rank == rank
    if allowed:
        self.logger.log(level, msg, *args, **kwargs)
Delegate a log call to the underlying logger, after prefixing its
message with the rank of the process it's being logged from. If
'rank' is provided, then the log will only occur on that
rank/process.
Parameters
- level: The level to log at. Look at `logging.__init__.py` for more information.
- msg: The message to log.
- rank: The rank to log at.
- args: Additional args to pass to the underlying logging function.
- kwargs: Any additional keyword args to pass to the underlying logging function.
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich
    library.

    Fields listed in ``print_order`` come first (missing ones are skipped with
    a warning), followed by every remaining top-level field of ``cfg``.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are
        printed. Default is ``("data", "model", "callbacks", "logger",
        "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig.
        Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder
        (``cfg.paths.output_dir/config_tree.log``). Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue
    for field in print_order:
        if field in cfg:
            queue.append(field)
        else:
            log.warning(
                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
            )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file; the tree uses box-drawing characters, so write
    # UTF-8 explicitly instead of relying on the platform default encoding
    if save_to_file:
        with open(
            Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8"
        ) as file:
            rich.print(tree, file=file)
Prints the contents of a DictConfig as a tree structure using the Rich library.
Parameters
- cfg: A DictConfig composed by Hydra.
- print_order: Determines in what order config components are printed. Default is `("data", "model", "callbacks", "logger", "trainer", "paths", "extras")`.
- resolve: Whether to resolve reference fields of DictConfig. Default is `False`.
- save_to_file: Whether to export config to the hydra output folder. Default is `False`.
def ax_wrapper(task_func: Callable) -> Callable:
    """Decorator that runs an Ax trial's task function and handles failures.

    The returned ``wrap(cfg, trial, output_dir="./")``:
    - calls ``task_func(cfg=cfg)``, which must return ``(metric_dict, status)``;
    - on ``KeyboardInterrupt`` or ``RuntimeError``: logs the exception, marks
      the trial as failed via ``trial[1].log_trial_failure(trial[0])``, tries
      up to 5 times to save the experiment to
      ``{output_dir}/killed_experiment.json``, sleeps 5 seconds and re-raises
      the original exception;
    - on any other ``Exception``: logs it and returns ``({}, False)`` instead
      of raising, so one bad hyperparameter combination does not abort the
      sweep;
    - always logs the output dir and, if wandb is installed and a run is
      active, finishes it with exit code 0 when ``status`` is truthy, else 1.

    Example:
    ```
    @utils.ax_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], bool]:
        ...
        return metric_dict, status
    ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """

    def wrap(
        cfg: DictConfig, trial: Tuple[int, AxClient], output_dir="./"
    ) -> Tuple[Dict[str, Any], bool]:
        # `trial` is (trial_index, ax_client)
        metric_dict = {}
        status = False
        # execute the task
        try:
            metric_dict, status = task_func(cfg=cfg)

        except (KeyboardInterrupt, RuntimeError):
            log.exception(
                f"Prematurely killed at time: {time.strftime('%Y-%m-%d %H:%M:%S')}"
            )
            # save experiment
            trial[1].log_trial_failure(trial[0])
            for i in range(5):
                try:
                    trial[1].save_to_json_file(
                        f"{output_dir}/killed_experiment.json"
                    )
                    break
                # A distinct name is required: `except ... as ex` unbinds its
                # target when the handler exits, which used to make the final
                # re-raise fail with UnboundLocalError after a failed save.
                except Exception as save_err:
                    if i == 4:
                        log.exception(
                            f"Could not save experiment to json file! <{type(save_err).__name__}>"
                        )

            # the above needs some time to save the file, but it is not blocking
            # so we need to wait a bit before exiting
            time.sleep(5)
            # bare `raise` re-raises the interrupt/RuntimeError being handled
            raise
        # things to do if exception occurs
        except Exception as ex:
            log.exception(f"Exception occurred! type: <{type(ex).__name__}>")
            # The exception is deliberately not re-raised: some hyperparameter
            # combinations might be invalid or cause out-of-memory errors, and
            # raising would make the whole sweep fail. The caller receives
            # status=False instead.
        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    exit_code = 0 if status else 1
                    wandb.finish(exit_code=exit_code)

        return metric_dict, status

    return wrap
Optional decorator that controls the failure behavior when executing the task function.
This wrapper can be used to:
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- save the exception to a .log file
- mark the run as failed with a dedicated file in the logs/ folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
@utils.ax_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], bool]:
...
return metric_dict, status
Parameters
- task_func: The task function to be wrapped.
Returns
The wrapped task function.
def extras(cfg: DictConfig) -> None:
    """Apply optional utilities before the task is started.

    Controlled by flags under ``cfg.extras``:
        - ``ignore_warnings``: silence all python warnings
        - ``enforce_tags``: call ``rich_utils.enforce_tags`` (saving to file)
        - ``print_config``: print the resolved config tree with Rich (saving
          to file)

    :param cfg: A DictConfig object containing the config tree.
    """
    # nothing to do when the `extras` group is missing or empty
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
Applies optional utilities before the task is started.
Utilities:
- Ignoring python warnings
- Setting tags from command line
- Rich config printing
Parameters
- cfg: A DictConfig object containing the config tree.
def get_metric_value(
    metric_dict: Dict[str, Any], metric_name: Optional[str]
) -> Optional[float]:
    """Look up a metric logged by a LightningModule and return it as a number.

    :param metric_dict: A dict containing metric values; each value must
        provide ``.item()`` (e.g. a scalar tensor).
    :param metric_name: If provided, the name of the metric to retrieve.
    :raises Exception: If ``metric_name`` is not a key of ``metric_dict``.
    :return: The metric's ``.item()`` value, or ``None`` when no metric name
        was given.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name in metric_dict:
        value = metric_dict[metric_name].item()
        log.info(f"Retrieved metric value! <{metric_name}={value}>")
        return value

    raise Exception(
        f"Metric value not found! <metric_name={metric_name}>\n"
        "Make sure metric name logged in LightningModule is correct!\n"
        "Make sure `optimized_metric` name in `hparams_search` config is correct!"
    )
Safely retrieves value of the metric logged in LightningModule.
Parameters
- metric_dict: A dict containing metric values.
- metric_name: If provided, the name of the metric to retrieve.
Returns
If a metric name was provided, the value of the metric.
def task_wrapper(task_func: Callable) -> Callable:
    """Decorator that logs task failures and always cleans up afterwards.

    The returned ``wrap(cfg)``:
    - calls ``task_func(cfg=cfg)`` and returns its ``(metric_dict,
      object_dict)``;
    - on any ``Exception``, logs it with its traceback and re-raises it;
    - always logs ``cfg.paths.output_dir`` and, if wandb is installed and a
      run is active, finishes that run with exit code 0 on success and 1 on
      failure (so a multirun is not left with a dangling wandb run).

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """

    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        finished_ok = False
        try:
            metric_dict, object_dict = task_func(cfg=cfg)
            finished_ok = True
        except Exception as ex:
            # record the traceback before propagating; hparam-search setups may
            # prefer to swallow the error here to keep a multirun alive
            log.exception("")
            raise ex
        finally:
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # close any active wandb run, even on failure
            if find_spec("wandb"):  # only if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish(exit_code=0 if finished_ok else 1)

        return metric_dict, object_dict

    return wrap
Optional decorator that controls the failure behavior when executing the task function.
This wrapper can be used to:
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- save the exception to a .log file
- mark the run as failed with a dedicated file in the logs/ folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
...
return metric_dict, object_dict
Parameters
- task_func: The task function to be wrapped.
Returns
The wrapped task function.
class MemoryFormat(Callback):
    """Callback that converts the model to a given memory format for a run.

    In ``setup`` the model is converted to ``memory_format`` (by default
    ``torch.channels_last``); in ``teardown`` it is converted to
    ``torch.contiguous_format``. Setting the memory format to channels_last
    usually improves GPU utilization, see
    https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html.
    Runs on setup, so it can set the memory format before the model is DDP
    wrapped.

    Args:
        memory_format (torch.memory_format): Target memory format for the
            model (and for input tensors when ``convert_input`` is True).
            Defaults to ``torch.channels_last``.
        convert_input (bool): Whether to also convert batch tensors at the
            start of each training batch. Defaults to False.
    """

    def __init__(
        self,
        memory_format: torch.memory_format = torch.channels_last,
        convert_input: bool = False,
    ):
        self.memory_format = memory_format
        self.convert_input = convert_input

    def setup(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        stage: Optional[str] = None,
    ) -> None:
        """Convert the module to the configured memory format.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (pl.LightningModule): The PyTorch Lightning module to be
                configured.
            stage (Optional[str]): The stage of the training process. Defaults
                to None.

        Notes:
            - If the memory format is ``torch.channels_last`` or
              ``torch.channels_last_3d`` and no layer of the model is listed in
              ``benefitial_layers``, a ``RuntimeWarning`` is issued.
            - The ``pl_module`` is moved to the specified memory format.
        """
        if self.memory_format in (
            torch.channels_last,
            torch.channels_last_3d,
        ) and not self.has_layer_benefiting_from_channels_last(pl_module):
            rank_zero_warn(
                f"model does not have any layers benefiting from {self.memory_format} format",
                category=RuntimeWarning,
            )

        pl_module.to(memory_format=self.memory_format)

    def teardown(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        stage: Optional[str] = None,
    ) -> None:
        """Convert the module to ``torch.contiguous_format``.

        Note that this does not restore whatever format the model had before
        ``setup``; it always converts to the contiguous format.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (pl.LightningModule): The PyTorch Lightning module being
                trained.
            stage (Optional[str]): The stage of the training process (e.g.,
                'fit', 'test', etc.). Defaults to None.
        """
        pl_module.to(memory_format=torch.contiguous_format)

    def on_train_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Optionally convert the batch's tensors to the memory format.

        Does nothing when ``convert_input`` is False. If the batch is not a
        ``MutableSequence``, a warning is issued and nothing is converted.
        Otherwise, every tensor element whose rank matches the target format
        (5 for ``torch.channels_last_3d``, 4 for any other format) is
        converted in place in the batch; other elements are left unchanged.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (pl.LightningModule): The LightningModule being trained.
            batch (Any): The input batch of data.
            batch_idx (int): The index of the current batch.
        """
        if not self.convert_input:
            return

        if not isinstance(batch, MutableSequence):
            rank_zero_warn(
                f"batch is not a MutableSequence, cannot convert input to {self.memory_format}",
                category=RuntimeWarning,
            )
            return

        # channels_last requires rank-4 tensors and channels_last_3d requires
        # rank-5 tensors (other ranks raise), so only convert tensors whose
        # rank matches the target format.
        target_ndim = 5 if self.memory_format == torch.channels_last_3d else 4
        for i, item in enumerate(batch):
            if isinstance(item, torch.Tensor) and item.ndim == target_ndim:
                batch[i] = item.to(memory_format=self.memory_format)

    # Layer types treated as benefiting from channels_last formats.
    # (Attribute name keeps its original spelling because it is public.)
    benefitial_layers = (
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
    )

    def has_layer_benefiting_from_channels_last(
        self, model: torch.nn.Module
    ) -> bool:
        """Check whether the model contains a layer that benefits from the
        "channels_last" memory format.

        Args:
            model (torch.nn.Module): The PyTorch model to inspect.

        Returns:
            bool: True if at least one submodule of ``model`` (including the
            model itself) is an instance of a type in
            ``self.benefitial_layers``, otherwise False.
        """
        return any(
            isinstance(layer, self.benefitial_layers)
            for layer in model.modules()
        )
The MemoryFormat callback changes the model memory format to
`torch.channels_last` before training starts and converts it to
`torch.contiguous_format` when it ends. See the
[PyTorch memory format tutorial](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html). Setting the memory format to channels_last usually improves GPU utilization. Runs on setup, so it can set the memory format before the model is DDP wrapped.
49 def setup( 50 self, 51 trainer: "pl.Trainer", 52 pl_module: "pl.LightningModule", 53 stage: Optional[str] = None, 54 ) -> None: 55 """ 56 Sets up the memory format for the given PyTorch Lightning module during the training process. 57 Args: 58 trainer (pl.Trainer): The PyTorch Lightning Trainer instance. 59 pl_module (pl.LightningModule): The PyTorch Lightning module to be configured. 60 stage (Optional[str]): The stage of the training process. Defaults to None. 61 Notes: 62 - If the specified memory format (e.g., `torch.channels_last` or `torch.channels_last_3d`) 63 does not benefit any layers in the model, a warning will be issued. 64 - The `pl_module` is moved to the specified memory format. 65 """ 66 67 if self.memory_format in ( 68 torch.channels_last, 69 torch.channels_last_3d, 70 ) and not self.has_layer_benefiting_from_channels_last(pl_module): 71 rank_zero_warn( 72 f"model does not have any layers benefiting from {self.memory_format} format", 73 category=RuntimeWarning, 74 ) 75 76 pl_module.to(memory_format=self.memory_format)
Sets up the memory format for the given PyTorch Lightning module during the training process.
Args:
trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
pl_module (pl.LightningModule): The PyTorch Lightning module to be configured.
stage (Optional[str]): The stage of the training process. Defaults to None.
Notes:
- If the specified memory format (e.g., torch.channels_last or torch.channels_last_3d)
does not benefit any layers in the model, a warning will be issued.
- The pl_module is moved to the specified memory format.
78 def teardown( 79 self, 80 trainer: "pl.Trainer", 81 pl_module: "pl.LightningModule", 82 stage: Optional[str] = None, 83 ) -> None: 84 """Handles the teardown process for the PyTorch Lightning module by 85 converting the module's memory format to contiguous format. 86 87 Args: 88 trainer (pl.Trainer): The PyTorch Lightning Trainer instance. 89 pl_module (pl.LightningModule): The PyTorch Lightning module being trained. 90 stage (Optional[str]): The stage of the training process (e.g., 'fit', 'test', etc.). 91 Defaults to None. 92 Returns: 93 None 94 """ 95 96 pl_module.to(memory_format=torch.contiguous_format)
Handles the teardown process for the PyTorch Lightning module by converting the module's memory format to contiguous format.
Args:
    trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
    pl_module (pl.LightningModule): The PyTorch Lightning module being trained.
    stage (Optional[str]): The stage of the training process (e.g., 'fit', 'test', etc.). Defaults to None.
Returns:
    None
def on_train_batch_start(
    self,
    trainer: "pl.Trainer",
    pl_module: "pl.LightningModule",
    batch: Any,
    batch_idx: int,
) -> None:
    """Hook that is called at the start of each training batch.

    Optionally converts the batch's tensors to ``self.memory_format``. If
    ``convert_input`` is False, the method exits early without converting.
    If the batch is not a ``MutableSequence``, a warning is issued and no
    conversion is performed. Otherwise, each tensor element whose rank
    matches the target format (5 for ``torch.channels_last_3d``, 4 for any
    other format) is replaced in the batch by its converted version.

    Args:
        trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
        pl_module (pl.LightningModule): The LightningModule being trained.
        batch (Any): The input batch of data.
        batch_idx (int): The index of the current batch.

    Returns:
        None
    """
    if not self.convert_input:
        return

    if not isinstance(batch, MutableSequence):
        rank_zero_warn(
            f"batch is not a MutableSequence, cannot convert input to {self.memory_format}",
            category=RuntimeWarning,
        )
        return

    # channels_last requires rank-4 tensors and channels_last_3d requires
    # rank-5 tensors (other ranks raise), so pick the rank to convert from the
    # target format instead of always assuming 4-D inputs.
    target_ndim = 5 if self.memory_format == torch.channels_last_3d else 4
    for i, item in enumerate(batch):
        if isinstance(item, torch.Tensor) and item.ndim == target_ndim:
            batch[i] = item.to(memory_format=self.memory_format)
Hook that is called at the start of each training batch.
This method is used to optionally convert the input batch's tensors to a specified memory format.
If the convert_input flag is set to False, the method exits early without performing any conversion.
If the batch is not a MutableSequence, a warning is issued, and no conversion is performed.
Args:
    trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
    pl_module (pl.LightningModule): The LightningModule being trained.
    batch (Any): The input batch of data.
    batch_idx (int): The index of the current batch.
Returns: None
def has_layer_benefiting_from_channels_last(
    self, model: torch.nn.Module
) -> bool:
    """Check whether ``model`` contains a layer that benefits from the
    "channels_last" memory format.

    Args:
        model (torch.nn.Module): The PyTorch model to inspect.

    Returns:
        bool: True if at least one submodule of ``model`` (including the
        model itself) is an instance of a type in ``self.benefitial_layers``,
        otherwise False.
    """
    wanted = self.benefitial_layers
    for submodule in model.modules():
        if isinstance(submodule, wanted):
            return True
    return False
Determines if the given model contains any layer that would benefit from using the "channels_last" memory format.
Args:
    model (torch.nn.Module): The PyTorch model to inspect.
Returns:
bool: True if at least one layer in the model is an instance of
the types specified in self.benefitial_layers, otherwise False.