ptame.print_map
"""Hydra entry point that prints maps for a configured model and sample set."""

from typing import Any, Dict, Tuple

import hydra
import torch
from omegaconf import DictConfig

from ptame.models.ptame_module import PTAMELitModule  # noqa: E402
from ptame.utils import RankedLogger, extras  # noqa: E402
from ptame.utils.map_printer import SampleSet  # noqa: E402

log = RankedLogger(__name__, rank_zero_only=True)


def print_map(cfg: DictConfig) -> None:
    """Prints maps for the given model and subset.

    Instantiates the sample set (``cfg.data``), the network (``cfg.model``)
    and the map printer (``cfg.map_printer``), then runs
    ``map_printer.print_maps()``.

    When ``cfg.cam`` is falsy, ``cfg.model`` is instantiated as a
    ``PTAMELitModule`` and its ``.net`` is used. Its weights may be restored
    from ``cfg.ckpt_path`` (a checkpoint holding a ``"state_dict"`` entry).
    The auxiliary module ``net.attention.attention`` may additionally be
    restored from ``cfg.aux_ckpt_path``. When ``cfg.cam`` is truthy,
    ``cfg.model`` is instantiated and used directly as the network, and no
    checkpoint is loaded.

    The network is put in eval mode before it is handed to the printer.

    :param cfg: DictConfig configuration composed by Hydra.
    :raises Exception: Whatever ``map_printer.print_maps()`` raises is logged
        and re-raised unchanged.
    """
    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    sampleset: SampleSet = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    if not cfg.get("cam"):
        model: PTAMELitModule = hydra.utils.instantiate(cfg.model)
        net = model.net

        if ckpt_path := cfg.get("ckpt_path"):
            log.info(f"Loading model from checkpoint path {ckpt_path}")
            # Load onto the CPU so GPU-saved checkpoints also open on
            # CPU-only hosts. load_state_dict copies the tensors into the
            # model's existing parameters either way.
            loaded_sd = torch.load(
                ckpt_path, map_location="cpu", weights_only=True
            )["state_dict"]
            if any("_orig_mod" in k for k in loaded_sd):
                # Keys saved from a wrapped module (e.g. torch.compile) carry
                # an "_orig_mod." component. Strip it to match the plain model.
                log.info("Loading original model weights")
                for key in [k for k in loaded_sd if "_orig_mod" in k]:
                    loaded_sd[key.replace("_orig_mod.", "")] = loaded_sd.pop(
                        key
                    )
            missing_keys, unexpected_keys = model.load_state_dict(
                loaded_sd,
                strict=False,
            )
            # Missing "backbone" keys are deliberately not reported.
            # Presumably the backbone keeps its own pretrained weights
            # (TODO confirm).
            missing = [k for k in missing_keys if "backbone" not in k]
            if missing:
                log.info(f"Incompatible keys: {', '.join(missing)}")
            else:
                log.info("No incompatible keys found in model state dict")
            if unexpected_keys:
                log.info(f"Unexpected keys: {', '.join(unexpected_keys)}")
            else:
                log.info("No unexpected keys found in model state dict")

        if aux_path := cfg.get("aux_ckpt_path"):
            log.info(
                f"Loading auxiliary model from checkpoint path {aux_path}"
            )
            aux_model = net.attention.attention
            aux_ckpt = torch.load(
                aux_path, map_location="cpu", weights_only=True
            )["state_dict"]
            # Skip every key containing "fc" and drop "net." from the rest,
            # presumably so the names match aux_model's own state dict.
            aux_ckpt = {
                k.replace("net.", ""): v
                for k, v in aux_ckpt.items()
                if "fc" not in k
            }
            aux_missing, aux_unexpected = aux_model.load_state_dict(
                aux_ckpt,
                strict=False,
            )
            # State is reported here, before net.eval() is called below.
            log.info(f"Aux. model in eval mode: {not aux_model.training}")
            no_grad = not any(p.requires_grad for p in aux_model.parameters())
            log.info(f"Aux. model has no gradients: {no_grad}")
            if len(aux_missing) > 10:
                # Too many missing keys to list. Show the keys that did load
                # instead, sorted so the output is reproducible.
                compatible = sorted(
                    set(aux_model.state_dict()) - set(aux_missing)
                )
                display_keys = "Compatible keys: " + " ".join(compatible[:10])
                if len(compatible) > 10:
                    display_keys += (
                        f"... {len(compatible) - 10} more items remaining"
                    )
            else:
                display_keys = f"Incompatible keys: {aux_missing}"
            log.info(display_keys)
            log.info(f"Unexpected keys: {aux_unexpected}")
    else:
        net = hydra.utils.instantiate(cfg.model)

    log.info(f"Instantiating map printer <{cfg.map_printer._target_}>")
    map_printer = hydra.utils.instantiate(
        cfg.map_printer, model=net.eval(), sampleset=sampleset
    )

    log.info("Starting printing maps!")
    try:
        map_printer.print_maps()
    except Exception:
        log.error("Failed to print maps!")
        raise  # bare raise keeps the original traceback untouched
    log.info("Success!")


@hydra.main(
    version_base="1.3", config_path="configs", config_name="print_map.yaml"
)
def main(cfg: DictConfig) -> None:
    """Main entry point for printing maps.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)

    extras(cfg)

    print_map(cfg)


if __name__ == "__main__":
    main()
log =
<RankedLogger ptame.print_map (INFO)>
def
print_map( cfg: omegaconf.dictconfig.DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def print_map(cfg: DictConfig) -> None:
    """Prints maps for the given model and subset.

    Instantiates the sample set (``cfg.data``), the network (``cfg.model``)
    and the map printer (``cfg.map_printer``), then runs
    ``map_printer.print_maps()``.

    When ``cfg.cam`` is falsy, ``cfg.model`` is instantiated as a
    ``PTAMELitModule`` and its ``.net`` is used. Its weights may be restored
    from ``cfg.ckpt_path`` (a checkpoint holding a ``"state_dict"`` entry).
    The auxiliary module ``net.attention.attention`` may additionally be
    restored from ``cfg.aux_ckpt_path``. When ``cfg.cam`` is truthy,
    ``cfg.model`` is instantiated and used directly as the network, and no
    checkpoint is loaded.

    The network is put in eval mode before it is handed to the printer.

    :param cfg: DictConfig configuration composed by Hydra.
    :raises Exception: Whatever ``map_printer.print_maps()`` raises is logged
        and re-raised unchanged.
    """
    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    sampleset: SampleSet = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    if not cfg.get("cam"):
        model: PTAMELitModule = hydra.utils.instantiate(cfg.model)
        net = model.net

        if ckpt_path := cfg.get("ckpt_path"):
            log.info(f"Loading model from checkpoint path {ckpt_path}")
            # Load onto the CPU so GPU-saved checkpoints also open on
            # CPU-only hosts. load_state_dict copies the tensors into the
            # model's existing parameters either way.
            loaded_sd = torch.load(
                ckpt_path, map_location="cpu", weights_only=True
            )["state_dict"]
            if any("_orig_mod" in k for k in loaded_sd):
                # Keys saved from a wrapped module (e.g. torch.compile) carry
                # an "_orig_mod." component. Strip it to match the plain model.
                log.info("Loading original model weights")
                for key in [k for k in loaded_sd if "_orig_mod" in k]:
                    loaded_sd[key.replace("_orig_mod.", "")] = loaded_sd.pop(
                        key
                    )
            missing_keys, unexpected_keys = model.load_state_dict(
                loaded_sd,
                strict=False,
            )
            # Missing "backbone" keys are deliberately not reported.
            # Presumably the backbone keeps its own pretrained weights
            # (TODO confirm).
            missing = [k for k in missing_keys if "backbone" not in k]
            if missing:
                log.info(f"Incompatible keys: {', '.join(missing)}")
            else:
                log.info("No incompatible keys found in model state dict")
            if unexpected_keys:
                log.info(f"Unexpected keys: {', '.join(unexpected_keys)}")
            else:
                log.info("No unexpected keys found in model state dict")

        if aux_path := cfg.get("aux_ckpt_path"):
            log.info(
                f"Loading auxiliary model from checkpoint path {aux_path}"
            )
            aux_model = net.attention.attention
            aux_ckpt = torch.load(
                aux_path, map_location="cpu", weights_only=True
            )["state_dict"]
            # Skip every key containing "fc" and drop "net." from the rest,
            # presumably so the names match aux_model's own state dict.
            aux_ckpt = {
                k.replace("net.", ""): v
                for k, v in aux_ckpt.items()
                if "fc" not in k
            }
            aux_missing, aux_unexpected = aux_model.load_state_dict(
                aux_ckpt,
                strict=False,
            )
            # State is reported here, before net.eval() is called below.
            log.info(f"Aux. model in eval mode: {not aux_model.training}")
            no_grad = not any(p.requires_grad for p in aux_model.parameters())
            log.info(f"Aux. model has no gradients: {no_grad}")
            if len(aux_missing) > 10:
                # Too many missing keys to list. Show the keys that did load
                # instead, sorted so the output is reproducible.
                compatible = sorted(
                    set(aux_model.state_dict()) - set(aux_missing)
                )
                display_keys = "Compatible keys: " + " ".join(compatible[:10])
                if len(compatible) > 10:
                    display_keys += (
                        f"... {len(compatible) - 10} more items remaining"
                    )
            else:
                display_keys = f"Incompatible keys: {aux_missing}"
            log.info(display_keys)
            log.info(f"Unexpected keys: {aux_unexpected}")
    else:
        net = hydra.utils.instantiate(cfg.model)

    log.info(f"Instantiating map printer <{cfg.map_printer._target_}>")
    map_printer = hydra.utils.instantiate(
        cfg.map_printer, model=net.eval(), sampleset=sampleset
    )

    log.info("Starting printing maps!")
    try:
        map_printer.print_maps()
    except Exception:
        log.error("Failed to print maps!")
        raise  # bare raise keeps the original traceback untouched
    log.info("Success!")
Prints maps for the given model and subset.
@hydra.main(version_base='1.3', config_path='configs', config_name='print_map.yaml')
def
main(cfg: omegaconf.dictconfig.DictConfig) -> None:
@hydra.main(
    version_base="1.3",
    config_path="configs",
    config_name="print_map.yaml",
)
def main(cfg: DictConfig) -> None:
    """Command-line entry point that prints maps.

    Hydra composes ``cfg`` from the ``print_map.yaml`` config under
    ``configs`` before this function is invoked.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # Optional pre-run utilities driven by the config
    # (e.g. asking for tags when none are provided, printing the cfg tree).
    extras(cfg)
    # All of the real work happens in print_map.
    print_map(cfg)
Main entry point for printing maps.
Parameters
- cfg: DictConfig configuration composed by Hydra.