ptame.models.components.rand_model

 1import torch
 2
 3
 4def cascading_random_model(
 5    model: torch.nn.Module,
 6    rand_model: torch.nn.Module,
 7    layer_key: str = "layer1",
 8    random_before: bool = True,
 9    random_after: bool = False,
10) -> torch.nn.Module:
11    """
12    Randomize all layers that contain the specified key in their name.
13    """
14    before = random_before
15    encounter = False
16    after = random_after
17    mdl = model.state_dict()
18    rand_mdl = rand_model.state_dict()
19    for name in mdl.keys():
20        if layer_key in name:
21            mdl[name] = rand_mdl[name]
22            before = False
23            encounter = True
24        elif before:
25            mdl[name] = rand_mdl[name]
26        elif encounter and after:
27            mdl[name] = rand_mdl[name]
28    model.load_state_dict(mdl)
29    return model
def cascading_random_model( model: torch.nn.modules.module.Module, rand_model: torch.nn.modules.module.Module, layer_key: str = 'layer1', random_before: bool = True, random_after: bool = False) -> torch.nn.modules.module.Module:
 5def cascading_random_model(
 6    model: torch.nn.Module,
 7    rand_model: torch.nn.Module,
 8    layer_key: str = "layer1",
 9    random_before: bool = True,
10    random_after: bool = False,
11) -> torch.nn.Module:
12    """
13    Randomize all layers that contain the specified key in their name.
14    """
15    before = random_before
16    encounter = False
17    after = random_after
18    mdl = model.state_dict()
19    rand_mdl = rand_model.state_dict()
20    for name in mdl.keys():
21        if layer_key in name:
22            mdl[name] = rand_mdl[name]
23            before = False
24            encounter = True
25        elif before:
26            mdl[name] = rand_mdl[name]
27        elif encounter and after:
28            mdl[name] = rand_mdl[name]
29    model.load_state_dict(mdl)
30    return model

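Randomize all layers that contain the specified key in their name by copying the same-named weights from `rand_model`; optionally also randomize the layers before the first match (`random_before`) and the layers after a match (`random_after`).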