ptame.print_map

  1from typing import Any, Dict, Tuple
  2
  3import hydra
  4import torch
  5from omegaconf import DictConfig
  6
  7from ptame.models.ptame_module import PTAMELitModule  # noqa: E402
  8from ptame.utils import RankedLogger, extras  # noqa: E402
  9from ptame.utils.map_printer import SampleSet  # noqa: E402
 10
 11log = RankedLogger(__name__, rank_zero_only=True)
 12
 13
 14def print_map(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 15    """Prints maps for the given model and subset."""
 16
 17    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
 18    sampleset: SampleSet = hydra.utils.instantiate(cfg.data)
 19
 20    log.info(f"Instantiating model <{cfg.model._target_}>")
 21    if not cfg.get("cam"):
 22        model: PTAMELitModule = hydra.utils.instantiate(cfg.model)
 23        net = model.net
 24
 25        if ckpt_path := cfg.get("ckpt_path"):
 26            log.info(f"Loading model from checkpoint path {ckpt_path}")
 27            loaded_sd = torch.load(ckpt_path, weights_only=True)["state_dict"]
 28            if any("_orig_mod" in k for k in loaded_sd.keys()):
 29                log.info("Loading original model weights")
 30                keys = list(loaded_sd.keys())
 31                for key in keys:
 32                    if "_orig_mod" in key:
 33                        loaded_sd[key.replace("_orig_mod.", "")] = (
 34                            loaded_sd.pop(key)
 35                        )
 36            incompatible_keys = model.load_state_dict(
 37                loaded_sd,
 38                strict=False,
 39            )
 40            log.info(
 41                "Incompatible keys: ",
 42                ", ".join(a),
 43            ) if len(
 44                a := [i for i in incompatible_keys[0] if "backbone" not in i]
 45            ) else log.info("No incompatible keys found in model state dict")
 46            log.info("Unexpected keys: ", incompatible_keys[1]) if len(
 47                incompatible_keys[1]
 48            ) > 0 else log.info("No unexpected keys found in model state dict")
 49
 50        if aux_path := cfg.get("aux_ckpt_path"):
 51            log.info(
 52                f"Loading auxiliary model from checkpoint path {aux_path}"
 53            )
 54            aux_ckpt = torch.load(aux_path, weights_only=True)["state_dict"]
 55            aux_ckpt = {
 56                k.replace("net.", ""): v
 57                for k, v in aux_ckpt.items()
 58                if "fc" not in k
 59            }
 60            incompatible_keys = net.attention.attention.load_state_dict(
 61                aux_ckpt,
 62                strict=False,
 63            )
 64            # check if the model is in eval mode and has no gradients
 65            log.info(
 66                f"Aux. model in eval mode: {not net.attention.attention.training}"
 67            )
 68            log.info(
 69                f"Aux. model has no gradients: {not any(p.requires_grad for p in net.attention.attention.parameters())}"
 70            )
 71            # get compatible keys
 72            concrete_keys = list(
 73                set(net.attention.attention.state_dict().keys())
 74                - set(incompatible_keys[0])
 75            )
 76            display_keys = (
 77                (
 78                    "Compatible keys: "
 79                    + " ".join(concrete_keys[:10])
 80                    + f"... {len(concrete_keys) - 10} more items remaining"
 81                    if len(concrete_keys) > 10
 82                    else concrete_keys
 83                )
 84                if len(incompatible_keys[0]) > 10
 85                else f"Incompatible keys: {incompatible_keys[0]}"
 86            )
 87            log.info(display_keys)
 88            log.info(f"Unexpected keys: {incompatible_keys[1]}")
 89    else:
 90        net = hydra.utils.instantiate(cfg.model)
 91
 92    log.info(f"Instantiating map printer <{cfg.map_printer._target_}>")
 93    # map_printer = hydra.utils.instantiate()
 94    map_printer = hydra.utils.instantiate(
 95        cfg.map_printer, model=net.eval(), sampleset=sampleset
 96    )
 97
 98    log.info("Starting printing maps!")
 99    try:
100        map_printer.print_maps()
101        log.info("Success!")
102    except Exception as e:
103        log.error("Failed to print maps!")
104        raise e
105
106
@hydra.main(
    config_path="configs",
    config_name="print_map.yaml",
    version_base="1.3",
)
def main(cfg: DictConfig) -> None:
    """Hydra entry point that prints maps for the composed configuration.

    First runs the ``extras`` utilities on ``cfg`` (e.g. asking for tags when
    none are provided, printing the config tree), then delegates to
    ``print_map``.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    extras(cfg)  # optional pre-run utilities controlled by cfg
    print_map(cfg)


if __name__ == "__main__":
    main()
log = <RankedLogger ptame.print_map (INFO)>
@hydra.main(version_base='1.3', config_path='configs', config_name='print_map.yaml')
def main(cfg: omegaconf.dictconfig.DictConfig) -> None:
108@hydra.main(
109    version_base="1.3", config_path="configs", config_name="print_map.yaml"
110)
111def main(cfg: DictConfig) -> None:
112    """Hydra entry point: run ``extras`` on the config, then print maps.
113
114    :param cfg: DictConfig configuration composed by Hydra.
115    """
116    # apply optional extra utilities driven by cfg before printing
117    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
118
119    extras(cfg)
120
121    print_map(cfg)

Main entry point for printing maps.

Parameters
  • cfg: DictConfig configuration composed by Hydra.