ptame.models.components.rand_model
import torch


def cascading_random_model(
    model: torch.nn.Module,
    rand_model: torch.nn.Module,
    layer_key: str = "layer1",
    random_before: bool = True,
    random_after: bool = False,
) -> torch.nn.Module:
    """Overwrite selected state-dict entries of ``model`` with ``rand_model``'s.

    The entries of ``model.state_dict()`` are visited in iteration order and
    an entry is replaced by the same-named entry of ``rand_model`` when:

    * its name contains ``layer_key`` (always replaced);
    * it comes before the first name containing ``layer_key`` and
      ``random_before`` is true;
    * it comes after the first name containing ``layer_key``, does not
      itself contain the key, and ``random_after`` is true.

    If no name contains ``layer_key``, every entry counts as "before", so
    the whole model is replaced when ``random_before`` is true and nothing
    is replaced otherwise.

    Args:
        model: Module to modify. It is updated in place through
            ``load_state_dict``.
        rand_model: Source of the replacement values; presumably a randomly
            initialised module with the same architecture as ``model``
            (confirm against the caller).
        layer_key: Substring identifying the target entries.
        random_before: Also replace entries preceding the first match.
        random_after: Also replace non-matching entries following the first
            match.

    Returns:
        The same ``model`` object, after loading the updated state dict.

    Raises:
        KeyError: If ``rand_model`` lacks an entry that has to be copied.
    """
    state = model.state_dict()
    rand_state = rand_model.state_dict()

    # First pass: decide which entries get replaced. Before the first match
    # the "before" flag governs non-matching names; afterwards the "after"
    # flag does.
    key_seen = False
    selected = []
    for name in state:
        if layer_key in name:
            key_seen = True
            selected.append(name)
        elif random_after if key_seen else random_before:
            selected.append(name)

    # Report every missing entry at once instead of a bare KeyError on the
    # first one, so architecture mismatches are easy to diagnose.
    missing = [name for name in selected if name not in rand_state]
    if missing:
        raise KeyError(
            "rand_model has no state-dict entries for: " + ", ".join(missing)
        )

    for name in selected:
        state[name] = rand_state[name]

    model.load_state_dict(state)
    return model
def
cascading_random_model( model: torch.nn.modules.module.Module, rand_model: torch.nn.modules.module.Module, layer_key: str = 'layer1', random_before: bool = True, random_after: bool = False) -> torch.nn.modules.module.Module:
def cascading_random_model(
    model: torch.nn.Module,
    rand_model: torch.nn.Module,
    layer_key: str = "layer1",
    random_before: bool = True,
    random_after: bool = False,
) -> torch.nn.Module:
    """Overwrite selected state-dict entries of ``model`` with ``rand_model``'s.

    The entries of ``model.state_dict()`` are visited in iteration order and
    an entry is replaced by the same-named entry of ``rand_model`` when:

    * its name contains ``layer_key`` (always replaced);
    * it comes before the first name containing ``layer_key`` and
      ``random_before`` is true;
    * it comes after the first name containing ``layer_key``, does not
      itself contain the key, and ``random_after`` is true.

    If no name contains ``layer_key``, every entry counts as "before", so
    the whole model is replaced when ``random_before`` is true and nothing
    is replaced otherwise.

    Args:
        model: Module to modify. It is updated in place through
            ``load_state_dict``.
        rand_model: Source of the replacement values; presumably a randomly
            initialised module with the same architecture as ``model``
            (confirm against the caller).
        layer_key: Substring identifying the target entries.
        random_before: Also replace entries preceding the first match.
        random_after: Also replace non-matching entries following the first
            match.

    Returns:
        The same ``model`` object, after loading the updated state dict.

    Raises:
        KeyError: If ``rand_model`` lacks an entry that has to be copied.
    """
    state = model.state_dict()
    rand_state = rand_model.state_dict()

    # First pass: decide which entries get replaced. Before the first match
    # the "before" flag governs non-matching names; afterwards the "after"
    # flag does.
    key_seen = False
    selected = []
    for name in state:
        if layer_key in name:
            key_seen = True
            selected.append(name)
        elif random_after if key_seen else random_before:
            selected.append(name)

    # Report every missing entry at once instead of a bare KeyError on the
    # first one, so architecture mismatches are easy to diagnose.
    missing = [name for name in selected if name not in rand_state]
    if missing:
        raise KeyError(
            "rand_model has no state-dict entries for: " + ", ".join(missing)
        )

    for name in selected:
        state[name] = rand_state[name]

    model.load_state_dict(state)
    return model
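Randomize all layers whose name contains the specified key, optionally together with the layers that come before the first match (random_before) and/or the non-matching layers that come after it (random_after). The values are copied from rand_model, and model is modified in place and returned.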