File size: 1,815 Bytes
52d68d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
'''
Helper utilities for saving and loading the model.
'''

def check_model_checkpoint_consistency(ckpt_state_dict, model_state_dict, special_strs=None):
    """
    Check that the checkpoint provides every key the model expects.

    Each key of ``model_state_dict`` must either be present in
    ``ckpt_state_dict`` or contain at least one of ``special_strs`` as a
    substring (matching is by substring, not only by ending).

    ckpt_state_dict: The state dictionary of the checkpoint.
    model_state_dict: The state dictionary of the model.
    special_strs: An iterable of substrings; model keys containing any of
        them may be absent from the checkpoint. ``None`` (the default)
        means no key is exempt.

    Returns a dict holding the checkpoint entries for the model keys found
    in the checkpoint, in model-key order. Checkpoint keys the model does
    not have are not included.

    Raises KeyError for a model key that is neither in the checkpoint nor
    exempted by ``special_strs``.
    """
    # Treat the default as "no exemptions" instead of iterating over None,
    # and materialize once so a one-shot iterable can be scanned for every key.
    special_strs = () if special_strs is None else tuple(special_strs)

    filtered_ckpt = {}
    for key in model_state_dict:
        if key in ckpt_state_dict:
            filtered_ckpt[key] = ckpt_state_dict[key]
        elif not any(special_str in key for special_str in special_strs):
            raise KeyError(f"Key '{key}' not found in checkpoint and does not match any special endings.")
    return filtered_ckpt
        
def remove_module_prefix(state_dict):
    """
    Return a new dict with one leading 'module.' stripped from each key.

    Keys that do not start with 'module.' are copied unchanged, and values
    are passed through as-is. Only a single leading occurrence is removed,
    so 'module.module.x' becomes 'module.x'. The input mapping is not
    modified.

    NOTE(review): the 'module.' prefix is presumably the one added by
    DataParallel-style wrappers; confirm against the callers.
    """
    prefix = 'module.'
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }



# Zeroing a module's parameters reduces its impact at the beginning of training.
def zero_module(module):
    """
    Set every parameter of ``module`` to zero and hand the module back.

    Each item yielded by ``module.parameters()`` is detached and then
    zeroed in place; the same ``module`` object is returned so the call
    can be used inline.

    NOTE(review): assumes ``module`` is a torch ``nn.Module``, where the
    detached tensor shares storage with the parameter, so the parameter
    itself ends up zeroed.
    """
    for param in module.parameters():
        detached = param.detach()
        detached.zero_()
    return module


def filter_model_checkpoint(ckpt_state_dict, model_state_dict, need_strs=None):
    """
    Select the checkpoint entries the model wants, filtered by key name.

    A key is kept when it is a key of ``model_state_dict``, is present in
    ``ckpt_state_dict``, and contains at least one of ``need_strs`` as a
    substring.

    ckpt_state_dict: The state dictionary of the checkpoint.
    model_state_dict: The state dictionary of the model.
    need_strs: An iterable of substrings a key must contain to be kept.
        ``None`` (the default) applies no name filter, so every key shared
        by the model and the checkpoint is kept. An empty iterable keeps
        nothing.

    Returns a new dict mapping each kept key to its checkpoint value, in
    model-key order.
    """
    # Materialize once so a one-shot iterable can be scanned for every key;
    # None is kept as a marker for "no name filter" instead of being iterated.
    needles = None if need_strs is None else tuple(need_strs)

    def _is_needed(key):
        return needles is None or any(need_str in key for need_str in needles)

    return {
        key: ckpt_state_dict[key]
        for key in model_state_dict
        if key in ckpt_state_dict and _is_needed(key)
    }