ptame.utils

 1from .instantiators import instantiate_callbacks, instantiate_loggers
 2from .logging_utils import log_hyperparameters
 3from .memory_format import MemoryFormat
 4from .pylogger import RankedLogger
 5from .rich_utils import enforce_tags, print_config_tree
 6from .utils import ax_wrapper, extras, get_metric_value, task_wrapper
 7
 8__all__ = [
 9    "instantiate_callbacks",
10    "instantiate_loggers",
11    "log_hyperparameters",
12    "RankedLogger",
13    "enforce_tags",
14    "print_config_tree",
15    "ax_wrapper",
16    "extras",
17    "get_metric_value",
18    "task_wrapper",
19    "MemoryFormat",
20]
def instantiate_callbacks( callbacks_cfg: omegaconf.dictconfig.DictConfig) -> List[lightning.pytorch.callbacks.callback.Callback]:
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    Every entry of ``callbacks_cfg`` that is itself a ``DictConfig`` containing
    a ``_target_`` key is built with ``hydra.utils.instantiate``; any other
    entry is skipped.

    :param callbacks_cfg: A DictConfig object containing callback
        configurations.
    :return: A list of instantiated callbacks (empty when ``callbacks_cfg``
        is empty or ``None``).
    :raises TypeError: If ``callbacks_cfg`` is non-empty but not a DictConfig.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        # Wording kept identical to `instantiate_loggers` ("Skipping...").
        log.warning("No callback configs found! Skipping...")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        # Only nodes that name a class/function to build are instantiated.
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks

Instantiates callbacks from config.

Parameters
  • callbacks_cfg: A DictConfig object containing callback configurations.
Returns

A list of instantiated callbacks.

def instantiate_loggers( logger_cfg: omegaconf.dictconfig.DictConfig) -> List[lightning.pytorch.loggers.logger.Logger]:
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build Lightning loggers from a Hydra logger config.

    Only entries that are ``DictConfig`` nodes carrying a ``_target_`` key are
    passed to ``hydra.utils.instantiate``; all other entries are ignored.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers (empty when ``logger_cfg`` is
        empty or ``None``).
    :raises TypeError: If ``logger_cfg`` is non-empty but not a DictConfig.
    """
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    loggers: List[Logger] = []
    for _name, lg_conf in logger_cfg.items():
        is_buildable = isinstance(lg_conf, DictConfig) and "_target_" in lg_conf
        if not is_buildable:
            continue
        log.info(f"Instantiating logger <{lg_conf._target_}>")
        loggers.append(hydra.utils.instantiate(lg_conf))

    return loggers

Instantiates loggers from config.

Parameters
  • logger_cfg: A DictConfig object containing logger configurations.
Returns

A list of instantiated loggers.

@rank_zero_only
def log_hyperparameters(object_dict: dict[str, typing.Any]) -> bool:
@rank_zero_only
def log_hyperparameters(object_dict: dict[str, Any]) -> bool:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
        - Number of model parameters (total / trainable / non-trainable)
        - Forward, backbone and backward GFLOPs, unless ``cfg.log_flops`` is
          set to a falsy value
        - The attention resolution, if ``model.net`` has an ``attention``
          attribute

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    :return: ``True`` if the model has at least one trainable parameter,
        ``False`` otherwise. The same value is returned when there is no
        logger to write to. NOTE(review): because of ``@rank_zero_only``,
        non-zero ranks receive the decorator's default instead; confirm
        callers handle that.
    """
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    # Parameter counts are cheap and determine the return value, so compute
    # them before deciding whether there is anything to log.
    n_total = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    has_trainable = n_trainable != 0

    # Lightning's `Trainer.loggers` is a list (empty without loggers), so
    # merely accessing it never raises; test for emptiness instead.
    if not getattr(trainer, "loggers", None):
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return has_trainable

    cfg = OmegaConf.to_container(object_dict["cfg"], resolve=True)

    hparams: dict[str, Any] = {
        "model": cfg["model"],
        "model/params/total": n_total,
        "model/params/trainable": n_trainable,
        "model/params/non_trainable": n_total - n_trainable,
    }

    if cfg.get("log_flops", True):
        hparams.update(_measure_gflops(cfg["model"], has_trainable))

    # save explanation resolution
    if hasattr(model.net, "attention"):
        hparams["model/attention/resolution"] = model.net.attention.resolution

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]
    # Optional top-level entries; missing ones are logged as None.
    for key in ("callbacks", "extras", "task_name", "tags", "ckpt_path", "seed"):
        hparams[key] = cfg.get(key)

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
    return has_trainable


def _measure_gflops(model_cfg: Any, has_trainable: bool) -> dict[str, float]:
    """Measure GFLOPs of ``net`` on a freshly instantiated copy of the model.

    A throwaway model is built from ``model_cfg`` so the measurement never
    touches the module that will be trained. It is only built when FLOPs are
    requested. Building it and drawing the dummy input may consume the global
    RNG, so with FLOP logging disabled the RNG is no longer advanced here.

    :param model_cfg: The resolved ``cfg["model"]`` container.
    :param has_trainable: Whether the real model has trainable parameters.
        The backward pass is only measured when it does.
    :return: ``model/gflops/fwd``, ``model/gflops/bb_fwd`` and
        ``model/gflops/bwd`` entries.
    """
    test_model = hydra.utils.instantiate(model_cfg)
    # NOTE(review): fixed single 3x224x224 dummy input -- assumes an RGB image
    # model at that resolution; confirm for other data.
    x = torch.randn(1, 3, 224, 224)

    # measure the flops caused by the backbone
    # NOTE(review): the factor 2 is kept from the original code; its meaning
    # is not visible here -- confirm.
    bb_flops = 2 * measure_flops(
        test_model.net, lambda: test_model.net.get_predictions(x)
    )
    fwd_flops = measure_flops(test_model.net, lambda: test_model.net(x))
    gflops: dict[str, float] = {
        "model/gflops/fwd": fwd_flops / 1e9,
        "model/gflops/bb_fwd": bb_flops / 1e9,
        "model/gflops/bwd": 0,
    }
    if has_trainable:
        fwd_and_bwd_flops = measure_flops(
            test_model.net,
            lambda: test_model.net(x),
            lambda out: test_model.criterion(**out)[0],
        )
        gflops["model/gflops/bwd"] = fwd_and_bwd_flops / 1e9
    return gflops

Controls which config parts are saved by Lightning loggers.

Additionally saves:

  • Number of model parameters

Parameters
  • object_dict: A dictionary containing the following objects:
    • "cfg": A DictConfig object containing the main config.
    • "model": The Lightning model.
    • "trainer": The Lightning trainer.
class RankedLogger(logging.LoggerAdapter):
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger.

    Every message is prefixed with the rank of the emitting process (via
    ``rank_prefixed_message``). Output can be restricted to rank zero, or,
    per call, to one chosen rank.
    """

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that
        logs on all processes with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        logger = logging.getLogger(name)
        super().__init__(logger=logger, extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(
        self,
        level: int,
        msg: str,
        *args,
        rank: Optional[int] = None,
        **kwargs,
    ) -> None:
        """Delegate a log call to the underlying logger, after prefixing its
        message with the rank of the process it's being logged from. If
        `'rank'` is provided, then the log will only occur on that
        rank/process.

        ``rank`` is keyword-only. ``LoggerAdapter.info/warning/...`` forward
        their ``%``-format arguments positionally to this method, so a
        positional ``rank`` would silently swallow the first format argument.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param args: Additional args to pass to the underlying logging function.
        :param rank: The rank to log at (keyword-only).
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        :raises RuntimeError: If ``rank_zero_only.rank`` has not been set.
        """
        if not self.isEnabledFor(level):
            return
        msg, kwargs = self.process(msg, kwargs)
        # `rank_zero_only` here is the module-level helper (presumably
        # Lightning's), not the constructor argument of the same name.
        current_rank = getattr(rank_zero_only, "rank", None)
        if current_rank is None:
            raise RuntimeError(
                "The `rank_zero_only.rank` needs to be set before use"
            )
        msg = rank_prefixed_message(msg, current_rank)
        if self.rank_zero_only:
            should_log = current_rank == 0
        else:
            should_log = rank is None or current_rank == rank
        if should_log:
            self.logger.log(level, msg, *args, **kwargs)

A multi-GPU-friendly python command line logger.

RankedLogger( name: str = 'ptame.utils.pylogger', rank_zero_only: bool = False, extra: Optional[Mapping[str, object]] = None)
14    def __init__(
15        self,
16        name: str = __name__,
17        rank_zero_only: bool = False,
18        extra: Optional[Mapping[str, object]] = None,
19    ) -> None:
20        """Initializes a multi-GPU-friendly python command line logger that
21        logs on all processes with their rank prefixed in the log message.
22
23        :param name: The name of the logger. Default is ``__name__``.
24        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
25        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
26        """
27        logger = logging.getLogger(name)
28        super().__init__(logger=logger, extra=extra)
29        self.rank_zero_only = rank_zero_only

Initializes a multi-GPU-friendly python command line logger that logs on all processes with their rank prefixed in the log message.

Parameters
  • name: The name of the logger. Default is __name__.
  • rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is False.
  • extra: (Optional) A dict-like object which provides contextual information. See logging.LoggerAdapter.
rank_zero_only
def log( self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
31    def log(
32        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
33    ) -> None:
34        """Delegate a log call to the underlying logger, after prefixing its
35        message with the rank of the process it's being logged from. If
36        `'rank'` is provided, then the log will only occur on that
37        rank/process.
38
39        :param level: The level to log at. Look at `logging.__init__.py` for more information.
40        :param msg: The message to log.
41        :param rank: The rank to log at.
42        :param args: Additional args to pass to the underlying logging function.
43        :param kwargs: Any additional keyword args to pass to the underlying logging function.
44        """
45        if self.isEnabledFor(level):
46            msg, kwargs = self.process(msg, kwargs)
47            current_rank = getattr(rank_zero_only, "rank", None)
48            if current_rank is None:
49                raise RuntimeError(
50                    "The `rank_zero_only.rank` needs to be set before use"
51                )
52            msg = rank_prefixed_message(msg, current_rank)
53            if self.rank_zero_only:
54                if current_rank == 0:
55                    self.logger.log(level, msg, *args, **kwargs)
56            else:
57                if rank is None:
58                    self.logger.log(level, msg, *args, **kwargs)
59                elif current_rank == rank:
60                    self.logger.log(level, msg, *args, **kwargs)

Delegate a log call to the underlying logger, after prefixing its message with the rank of the process it's being logged from. If 'rank' is provided, then the log will only occur on that rank/process.

Parameters
  • level: The level to log at. Look at logging.__init__.py for more information.
  • msg: The message to log.
  • rank: The rank to log at.
  • args: Additional args to pass to the underlying logging function.
  • kwargs: Any additional keyword args to pass to the underlying logging function.
@rank_zero_only
def enforce_tags(cfg: omegaconf.dictconfig.DictConfig, save_to_file: bool = False) -> None:
 83@rank_zero_only
 84def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
 85    """Prompts user to input tags from command line if no tags are provided in
 86    config.
 87
 88    :param cfg: A DictConfig composed by Hydra.
 89    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
 90    """
 91    if not cfg.get("tags"):
 92        if "id" in HydraConfig().cfg.hydra.job:
 93            raise ValueError("Specify tags before launching a multirun!")
 94
 95        log.warning(
 96            "No tags provided in config. Prompting user to input tags..."
 97        )
 98        tags = Prompt.ask(
 99            "Enter a list of comma separated tags", default="dev"
100        )
101        tags = [t.strip() for t in tags.split(",") if t != ""]
102
103        with open_dict(cfg):
104            cfg.tags = tags
105
106        log.info(f"Tags: {cfg.tags}")
107
108    if save_to_file:
109        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
110            rich.print(cfg.tags, file=file)

Prompts user to input tags from command line if no tags are provided in config.

Parameters
  • cfg: A DictConfig composed by Hydra.
  • save_to_file: Whether to export tags to the hydra output folder. Default is False.
def ax_wrapper(task_func: Callable) -> Callable:
def ax_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing an
    Ax trial's task function.

    The wrapped function is called as ``wrap(cfg, trial, output_dir="./")``.
    ``trial`` is a ``(trial_index, AxClient)`` pair. The call returns
    ``(metric_dict, status)`` as produced by ``task_func(cfg=cfg)``.

    - ``KeyboardInterrupt`` / ``RuntimeError``: the trial is marked as failed
      with ``log_trial_failure``. The experiment is saved to
      ``<output_dir>/killed_experiment.json`` (up to 5 attempts). The wrapper
      waits 5 s and then re-raises the original exception.
    - Any other ``Exception``: it is logged and swallowed, and ``({}, False)``
      is returned so the search can continue.
    - Always: the output dir is logged, and an active wandb run is finished
      with exit code 0 if ``status`` is truthy, else 1.

    NOTE(review): PyTorch's out-of-memory error subclasses ``RuntimeError``,
    so OOM takes the "kill and re-raise" path, not the "swallow" path --
    confirm this is intended.

    Example:
    ```
    @utils.ax_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], bool]:
        ...
        return metric_dict, status
    ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """

    def wrap(
        cfg: DictConfig, trial: Tuple[int, AxClient], output_dir="./"
    ) -> Tuple[Dict[str, Any], bool]:
        metric_dict = {}
        status = False
        max_save_attempts = 5
        # execute the task
        try:
            metric_dict, status = task_func(cfg=cfg)

        except (KeyboardInterrupt, RuntimeError):
            log.exception(
                f"Prematurely killed at time: {time.strftime('%Y-%m-%d %H:%M:%S')}"
            )
            # save experiment
            trial[1].log_trial_failure(trial[0])
            for attempt in range(max_save_attempts):
                try:
                    trial[1].save_to_json_file(
                        f"{output_dir}/killed_experiment.json"
                    )
                    break
                # Distinct name: reusing the outer handler's name would make
                # Python delete it at the end of this handler.
                except Exception as save_err:
                    if attempt == max_save_attempts - 1:
                        log.exception(
                            f"Could not save experiment to json file! <{type(save_err).__name__}>"
                        )

            # the above needs some time to save the file, but it is not blocking
            # so we need to wait a bit before exiting
            time.sleep(5)
            # Bare `raise` re-raises the interrupt/RuntimeError being handled.
            raise
        # things to do if exception occurs
        except Exception as ex:
            log.exception(f"Exception occurred! type: <{type(ex).__name__}>")
            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so the exception is deliberately swallowed to avoid multirun failure
        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    exit_code = 0 if status else 1
                    wandb.finish(exit_code=exit_code)

        return metric_dict, status

    return wrap

Optional decorator that controls the failure behavior when executing the task function.

This wrapper can be used to:

  • make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
  • save the exception to a .log file
  • mark the run as failed with a dedicated file in the logs/ folder (so we can find and rerun it later)
  • etc. (adjust depending on your needs)

Example:

@utils.ax_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], bool]:
    ...
    return metric_dict, status
Parameters
  • task_func: The task function to be wrapped.
Returns

The wrapped task function.

def extras(cfg: omegaconf.dictconfig.DictConfig) -> None:
def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities selected under ``cfg.extras``.

    Each utility is switched on by a flag in ``cfg.extras``:
        - ``ignore_warnings``: silence all python warnings
        - ``enforce_tags``: prompt for tags when none are set, and save them
        - ``print_config``: pretty-print the resolved config tree with Rich

    Only a warning is emitted when ``cfg.extras`` is missing or empty.

    :param cfg: A DictConfig object containing the config tree.
    """
    # Without an `extras` section there is nothing to toggle.
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    options = cfg.extras

    if options.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    if options.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    if options.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)

Applies optional utilities before the task is started.

Utilities:

  • Ignoring python warnings
  • Setting tags from command line
  • Rich config printing

Parameters
  • cfg: A DictConfig object containing the config tree.
def get_metric_value( metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
def get_metric_value(
    metric_dict: Dict[str, Any], metric_name: Optional[str]
) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    :param metric_dict: A dict containing metric values. Values may be
        tensor-like objects exposing ``.item()`` or plain numbers.
    :param metric_name: If provided, the name of the metric to retrieve.
    :return: If a metric name was provided, the value of the metric as a
        Python number; otherwise ``None``.
    :raises Exception: If ``metric_name`` is not a key of ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    raw_value = metric_dict[metric_name]
    # Tensor values (presumably Lightning's callback metrics) are unwrapped
    # with `.item()`. Plain numbers are accepted instead of raising
    # AttributeError.
    metric_value = (
        raw_value.item() if hasattr(raw_value, "item") else float(raw_value)
    )
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value

Safely retrieves value of the metric logged in LightningModule.

Parameters
  • metric_dict: A dict containing metric values.
  • metric_name: If provided, the name of the metric to retrieve.
Returns

If a metric name was provided, the value of the metric.

def task_wrapper(task_func: Callable) -> Callable:
 50def task_wrapper(task_func: Callable) -> Callable:
 51    """Optional decorator that controls the failure behavior when executing the
 52    task function.
 53
 54    This wrapper can be used to:
 55        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
 56        - save the exception to a `.log` file
 57        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
 58        - etc. (adjust depending on your needs)
 59
 60    Example:
 61    ```
 62    @utils.task_wrapper
 63    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 64        ...
 65        return metric_dict, object_dict
 66    ```
 67
 68    :param task_func: The task function to be wrapped.
 69
 70    :return: The wrapped task function.
 71    """
 72
 73    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 74        # execute the task
 75        exit_code = 1
 76        try:
 77            metric_dict, object_dict = task_func(cfg=cfg)
 78            exit_code = 0
 79        # things to do if exception occurs
 80        except Exception as ex:
 81            # save exception to `.log` file
 82            log.exception("")
 83            # some hyperparameter combinations might be invalid or cause out-of-memory errors
 84            # so when using hparam search plugins like Optuna, you might want to disable
 85            # raising the below exception to avoid multirun failure
 86            raise ex
 87
 88        # things to always do after either success or exception
 89        finally:
 90            # display output dir path in terminal
 91            log.info(f"Output dir: {cfg.paths.output_dir}")
 92
 93            # always close wandb run (even if exception occurs so multirun won't fail)
 94            if find_spec("wandb"):  # check if wandb is installed
 95                import wandb
 96
 97                if wandb.run:
 98                    log.info("Closing wandb!")
 99                    wandb.finish(exit_code=exit_code)
100
101        return metric_dict, object_dict
102
103    return wrap

Optional decorator that controls the failure behavior when executing the task function.

This wrapper can be used to:

  • make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
  • save the exception to a .log file
  • mark the run as failed with a dedicated file in the logs/ folder (so we can find and rerun it later)
  • etc. (adjust depending on your needs)

Example:

@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    ...
    return metric_dict, object_dict
Parameters
  • task_func: The task function to be wrapped.
Returns

The wrapped task function.

class MemoryFormat(lightning.pytorch.callbacks.callback.Callback):
 31class MemoryFormat(Callback):
 32    """The `MemoryFormat` callback changes the model memory format to
 33    `torch.channels_last` before training starts and returns the original when
 34    it ends.
 35
 36    <https://\\pytorch.org/tutorials/intermediate/memory_format_tutorial.html>`_.
 37    Setting the memory format channels_last usually improves GPU utilization.
 38    Runs on setup, so it can set the memory format before the model is DDP wrapped.
 39    """
 40
 41    def __init__(
 42        self,
 43        memory_format: torch.memory_format = torch.channels_last,
 44        convert_input: bool = False,
 45    ):
 46        self.memory_format = memory_format
 47        self.convert_input = convert_input
 48
 49    def setup(
 50        self,
 51        trainer: "pl.Trainer",
 52        pl_module: "pl.LightningModule",
 53        stage: Optional[str] = None,
 54    ) -> None:
 55        """
 56        Sets up the memory format for the given PyTorch Lightning module during the training process.
 57        Args:
 58            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
 59            pl_module (pl.LightningModule): The PyTorch Lightning module to be configured.
 60            stage (Optional[str]): The stage of the training process. Defaults to None.
 61        Notes:
 62            - If the specified memory format (e.g., `torch.channels_last` or `torch.channels_last_3d`)
 63              does not benefit any layers in the model, a warning will be issued.
 64            - The `pl_module` is moved to the specified memory format.
 65        """
 66
 67        if self.memory_format in (
 68            torch.channels_last,
 69            torch.channels_last_3d,
 70        ) and not self.has_layer_benefiting_from_channels_last(pl_module):
 71            rank_zero_warn(
 72                f"model does not have any layers benefiting from {self.memory_format} format",
 73                category=RuntimeWarning,
 74            )
 75
 76        pl_module.to(memory_format=self.memory_format)
 77
 78    def teardown(
 79        self,
 80        trainer: "pl.Trainer",
 81        pl_module: "pl.LightningModule",
 82        stage: Optional[str] = None,
 83    ) -> None:
 84        """Handles the teardown process for the PyTorch Lightning module by
 85        converting the module's memory format to contiguous format.
 86
 87        Args:
 88            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
 89            pl_module (pl.LightningModule): The PyTorch Lightning module being trained.
 90            stage (Optional[str]): The stage of the training process (e.g., 'fit', 'test', etc.).
 91                Defaults to None.
 92        Returns:
 93            None
 94        """
 95
 96        pl_module.to(memory_format=torch.contiguous_format)
 97
 98    def on_train_batch_start(
 99        self,
100        trainer: "pl.Trainer",
101        pl_module: "pl.LightningModule",
102        batch: Any,
103        batch_idx: int,
104    ) -> None:
105        """Hook that is called at the start of each training batch.
106
107        This method is used to optionally convert the input batch's tensors to a specified memory format.
108        If the `convert_input` flag is set to `False`, the method exits early without performing any conversion.
109        If the batch is not a `MutableSequence`, a warning is issued, and no conversion is performed.
110
111        Args:
112            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
113            pl_module (pl.LightningModule): The LightningModule being trained.
114            batch (Any): The input batch of data.
115            batch_idx (int): The index of the current batch.
116
117        Returns:
118            None
119        """
120        if not self.convert_input:
121            return
122
123        if not isinstance(batch, MutableSequence):
124            rank_zero_warn(
125                f"batch is not a MutableSequence, cannot convert input to {self.memory_format}",
126                category=RuntimeWarning,
127            )
128            return
129
130        for i, item in enumerate(batch):
131            if isinstance(item, torch.Tensor) and item.ndim == 4:
132                batch[i] = item.to(memory_format=self.memory_format)
133
134    benefitial_layers = (
135        torch.nn.BatchNorm2d,
136        torch.nn.BatchNorm3d,
137        torch.nn.Conv2d,
138        torch.nn.Conv3d,
139    )
140
141    def has_layer_benefiting_from_channels_last(
142        self, model: torch.nn.Module
143    ) -> bool:
144        """Determines if the given model contains any layer that would benefit
145        from using the "channels_last" memory format.
146
147        Args:
148            model (torch.nn.Module): The PyTorch model to inspect.
149
150        Returns:
151            bool: True if at least one layer in the model is an instance of
152            the types specified in `self.benefitial_layers`, otherwise False.
153        """
154        return any(
155            isinstance(layer, self.benefitial_layers)
156            for layer in model.modules()
157        )

The MemoryFormat callback changes the model memory format to `torch.channels_last` before training starts (in `setup`) and resets it to `torch.contiguous_format` when it ends (in `teardown`).

Setting the memory format channels_last usually improves GPU utilization; see the [PyTorch memory format tutorial](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html). Runs on setup, so it can set the memory format before the model is DDP wrapped.

MemoryFormat( memory_format: torch.memory_format = torch.channels_last, convert_input: bool = False)
41    def __init__(
42        self,
43        memory_format: torch.memory_format = torch.channels_last,
44        convert_input: bool = False,
45    ):
46        self.memory_format = memory_format
47        self.convert_input = convert_input
memory_format
convert_input
def setup( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule, stage: Optional[str] = None) -> None:
49    def setup(
50        self,
51        trainer: "pl.Trainer",
52        pl_module: "pl.LightningModule",
53        stage: Optional[str] = None,
54    ) -> None:
55        """
56        Sets up the memory format for the given PyTorch Lightning module during the training process.
57        Args:
58            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
59            pl_module (pl.LightningModule): The PyTorch Lightning module to be configured.
60            stage (Optional[str]): The stage of the training process. Defaults to None.
61        Notes:
62            - If the specified memory format (e.g., `torch.channels_last` or `torch.channels_last_3d`)
63              does not benefit any layers in the model, a warning will be issued.
64            - The `pl_module` is moved to the specified memory format.
65        """
66
67        if self.memory_format in (
68            torch.channels_last,
69            torch.channels_last_3d,
70        ) and not self.has_layer_benefiting_from_channels_last(pl_module):
71            rank_zero_warn(
72                f"model does not have any layers benefiting from {self.memory_format} format",
73                category=RuntimeWarning,
74            )
75
76        pl_module.to(memory_format=self.memory_format)

Sets up the memory format for the given PyTorch Lightning module during the training process.

Args:
  • trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
  • pl_module (pl.LightningModule): The PyTorch Lightning module to be configured.
  • stage (Optional[str]): The stage of the training process. Defaults to None.

Notes:
  • If the specified memory format (e.g., torch.channels_last or torch.channels_last_3d) does not benefit any layers in the model, a warning will be issued.
  • The pl_module is moved to the specified memory format.

def teardown( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule, stage: Optional[str] = None) -> None:
78    def teardown(
79        self,
80        trainer: "pl.Trainer",
81        pl_module: "pl.LightningModule",
82        stage: Optional[str] = None,
83    ) -> None:
84        """Handles the teardown process for the PyTorch Lightning module by
85        converting the module's memory format to contiguous format.
86
87        Args:
88            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
89            pl_module (pl.LightningModule): The PyTorch Lightning module being trained.
90            stage (Optional[str]): The stage of the training process (e.g., 'fit', 'test', etc.).
91                Defaults to None.
92        Returns:
93            None
94        """
95
96        pl_module.to(memory_format=torch.contiguous_format)

Handles the teardown process for the PyTorch Lightning module by converting the module's memory format to contiguous format.

Args:
  • trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
  • pl_module (pl.LightningModule): The PyTorch Lightning module being trained.
  • stage (Optional[str]): The stage of the training process (e.g., 'fit', 'test', etc.). Defaults to None.

Returns: None

def on_train_batch_start( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule, batch: Any, batch_idx: int) -> None:
 98    def on_train_batch_start(
 99        self,
100        trainer: "pl.Trainer",
101        pl_module: "pl.LightningModule",
102        batch: Any,
103        batch_idx: int,
104    ) -> None:
105        """Hook that is called at the start of each training batch.
106
107        This method is used to optionally convert the input batch's tensors to a specified memory format.
108        If the `convert_input` flag is set to `False`, the method exits early without performing any conversion.
109        If the batch is not a `MutableSequence`, a warning is issued, and no conversion is performed.
110
111        Args:
112            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
113            pl_module (pl.LightningModule): The LightningModule being trained.
114            batch (Any): The input batch of data.
115            batch_idx (int): The index of the current batch.
116
117        Returns:
118            None
119        """
120        if not self.convert_input:
121            return
122
123        if not isinstance(batch, MutableSequence):
124            rank_zero_warn(
125                f"batch is not a MutableSequence, cannot convert input to {self.memory_format}",
126                category=RuntimeWarning,
127            )
128            return
129
130        for i, item in enumerate(batch):
131            if isinstance(item, torch.Tensor) and item.ndim == 4:
132                batch[i] = item.to(memory_format=self.memory_format)

Hook that is called at the start of each training batch.

This method is used to optionally convert the input batch's tensors to a specified memory format. If the convert_input flag is set to False, the method exits early without performing any conversion. If the batch is not a MutableSequence, a warning is issued, and no conversion is performed.

Args:
  • trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
  • pl_module (pl.LightningModule): The LightningModule being trained.
  • batch (Any): The input batch of data.
  • batch_idx (int): The index of the current batch.

Returns: None

benefitial_layers = (<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, <class 'torch.nn.modules.batchnorm.BatchNorm3d'>, <class 'torch.nn.modules.conv.Conv2d'>, <class 'torch.nn.modules.conv.Conv3d'>)
def has_layer_benefiting_from_channels_last(self, model: torch.nn.modules.module.Module) -> bool:
141    def has_layer_benefiting_from_channels_last(
142        self, model: torch.nn.Module
143    ) -> bool:
144        """Determines if the given model contains any layer that would benefit
145        from using the "channels_last" memory format.
146
147        Args:
148            model (torch.nn.Module): The PyTorch model to inspect.
149
150        Returns:
151            bool: True if at least one layer in the model is an instance of
152            the types specified in `self.benefitial_layers`, otherwise False.
153        """
154        return any(
155            isinstance(layer, self.benefitial_layers)
156            for layer in model.modules()
157        )

Determines if the given model contains any layer that would benefit from using the "channels_last" memory format.

Args:
  • model (torch.nn.Module): The PyTorch model to inspect.

Returns:
  • bool: True if at least one layer in the model is an instance of the types specified in self.benefitial_layers, otherwise False.