# (removed: page-scrape residue — hosting-UI header, file size, commit hash, line-number gutter)
'''
Utilities for saving and loading model checkpoint state dictionaries.
'''
def check_model_checkpoint_consistency(ckpt_state_dict, model_state_dict, special_strs=None):
    """
    Verify that every key of the model's state dict is covered by the checkpoint.

    Model keys containing any of the ``special_strs`` substrings are allowed
    to be absent from the checkpoint; every other model key must exist in the
    checkpoint, otherwise a KeyError is raised.

    Args:
        ckpt_state_dict: State dictionary loaded from the checkpoint.
        model_state_dict: State dictionary of the model being restored.
        special_strs: Optional list of substrings; model keys containing any
            of them are tolerated when missing from the checkpoint.

    Returns:
        dict: The checkpoint entries restricted to keys present in the model.

    Raises:
        KeyError: If a model key is missing from the checkpoint and does not
            match any special substring.
    """
    # BUG FIX: iterating over the default None in any() raised TypeError,
    # masking the intended KeyError. Treat None as "no special substrings".
    special_strs = special_strs if special_strs is not None else []
    filtered_ckpt = {}
    special_modules = []  # keys tolerated as missing (collected for debugging)
    for key in model_state_dict.keys():
        if key in ckpt_state_dict:
            filtered_ckpt[key] = ckpt_state_dict[key]
        elif any(special_str in key for special_str in special_strs):
            special_modules.append(key)
        else:
            raise KeyError(f"Key '{key}' not found in checkpoint and does not match any special endings.")
    # BUG FIX: the original built filtered_ckpt but never returned it,
    # so the filtered result was silently discarded.
    return filtered_ckpt
def remove_module_prefix(state_dict):
    """
    Return a copy of ``state_dict`` with a leading 'module.' stripped from
    every key (the prefix commonly added by DataParallel/DDP wrappers).
    Keys without the prefix are copied through unchanged.
    """
    prefix = 'module.'
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in state_dict.items()
    }
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the module.

    Used to reduce a freshly added module's impact at the beginning of
    training: a zeroed layer initially contributes nothing to the output.
    """
    for param in module.parameters():
        # detach() so the in-place zeroing is not recorded by autograd.
        param.detach().zero_()
    return module
def filter_model_checkpoint(ckpt_state_dict, model_state_dict, need_strs=None):
    """
    Select checkpoint entries for model keys, optionally restricted by substring.

    Args:
        ckpt_state_dict: State dictionary loaded from the checkpoint.
        model_state_dict: State dictionary of the model; only its keys are considered.
        need_strs: Optional list of substrings. When given, a key is kept only
            if it contains at least one of them. When None, every model key
            found in the checkpoint is kept.

    Returns:
        dict: Checkpoint entries whose keys are present in the model (and
        match ``need_strs`` when provided).
    """
    filtered_ckpt = {}
    for key in model_state_dict.keys():
        if key not in ckpt_state_dict:
            continue
        # BUG FIX: the default need_strs=None used to raise TypeError inside
        # any(); treat None as "no substring restriction".
        if need_strs is None or any(need_str in key for need_str in need_strs):
            filtered_ckpt[key] = ckpt_state_dict[key]
    return filtered_ckpt