alexzyqi's picture
20240706
52d68d4
'''
Utilities for saving and loading the model.
'''
def check_model_checkpoint_consistency(ckpt_state_dict, model_state_dict, special_strs=None):
    """
    Verify that every model key is either present in the checkpoint or exempt.

    A model key is exempt when it contains at least one entry of
    ``special_strs`` as a substring (anywhere in the key, not only at its end).

    Args:
        ckpt_state_dict: The state dictionary of the checkpoint.
        model_state_dict: The state dictionary of the model.
        special_strs: Iterable of substrings marking model keys that are
            allowed to be absent from the checkpoint. ``None`` (the default)
            means that no key is exempt.

    Returns:
        None. The function only validates the two dictionaries.

    Raises:
        KeyError: On the first model key that is neither in the checkpoint
            nor matched by ``special_strs``.
    """
    # Iterating the None default directly raised a TypeError as soon as a key
    # was missing; treat None as "no exemptions" so the KeyError surfaces.
    exempt_strs = () if special_strs is None else tuple(special_strs)

    for key in model_state_dict:
        if key in ckpt_state_dict:
            continue
        if any(special_str in key for special_str in exempt_strs):
            continue
        raise KeyError(f"Key '{key}' not found in checkpoint and does not match any special endings.")
def remove_module_prefix(state_dict):
    """
    Return a new dict whose keys have one leading 'module.' removed.

    Keys that do not start with 'module.' are copied unchanged, and values are
    passed through untouched. Only a single leading occurrence is stripped.
    """
    prefix = 'module.'
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): entry
        for name, entry in state_dict.items()
    }
# Zero-initializing a module reduces its impact at the beginning of training.
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the module.

    NOTE(review): assumes ``module`` is a torch.nn.Module-like object whose
    ``parameters()`` yields tensors supporting ``detach()`` and ``zero_()``.
    """
    for weight in module.parameters():
        detached = weight.detach()
        detached.zero_()
    return module
def filter_model_checkpoint(ckpt_state_dict, model_state_dict, need_strs=None):
    """
    Select the checkpoint entries that the model has and that are requested.

    Args:
        ckpt_state_dict: The state dictionary of the checkpoint.
        model_state_dict: The state dictionary of the model; only its keys
            are used.
        need_strs: Iterable of substrings. A key is kept only if it contains
            at least one of them. ``None`` (the default) applies no substring
            restriction, so every key shared by both dictionaries is kept.
            An empty iterable selects nothing.

    Returns:
        A new dict mapping each selected key to its checkpoint value, in the
        iteration order of ``model_state_dict``.
    """
    # Iterating the None default directly raised a TypeError on the first key
    # shared by both dictionaries; interpret None as "no filter" instead.
    if need_strs is None:
        return {
            key: ckpt_state_dict[key]
            for key in model_state_dict
            if key in ckpt_state_dict
        }

    wanted = tuple(need_strs)
    return {
        key: ckpt_state_dict[key]
        for key in model_state_dict
        if key in ckpt_state_dict and any(need_str in key for need_str in wanted)
    }