Spaces:
Runtime error
Runtime error
'''
Utilities for saving and loading model checkpoints.
'''
def check_model_checkpoint_consistency(ckpt_state_dict, model_state_dict, special_strs=None):
    """
    Collect the checkpoint entry for every key the model expects.

    Every key of ``model_state_dict`` must be present in ``ckpt_state_dict``
    unless the key contains one of the substrings in ``special_strs``; such
    keys are skipped. Checkpoint keys that the model does not have are dropped.

    Args:
        ckpt_state_dict: The state dictionary of the checkpoint.
        model_state_dict: The state dictionary of the model (only its keys
            are used).
        special_strs: Optional iterable of substrings. A model key missing
            from the checkpoint is tolerated if it contains any of them.
            ``None`` means no missing key is tolerated.

    Returns:
        dict: ``{key: ckpt_state_dict[key]}`` for every model key found in the
        checkpoint, in the model's key order.

    Raises:
        KeyError: If a model key is missing from the checkpoint and contains
            none of ``special_strs``.
    """
    # Materialize once: None would make any() raise TypeError, and a generator
    # would be exhausted after the first missing key.
    special_strs = () if special_strs is None else tuple(special_strs)
    filtered_ckpt = {}
    for key in model_state_dict:
        if key in ckpt_state_dict:
            filtered_ckpt[key] = ckpt_state_dict[key]
        elif not any(special_str in key for special_str in special_strs):
            raise KeyError(f"Key '{key}' not found in checkpoint and does not match any special endings.")
    return filtered_ckpt
def remove_module_prefix(state_dict):
    """
    Return a copy of ``state_dict`` with one leading ``'module.'`` stripped from each key.

    Keys without that prefix are copied unchanged, and the input dict is not
    modified. Key order follows the input. If stripping makes two keys equal,
    the one that comes later in iteration order wins.
    """
    prefix = 'module.'
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): val
        for name, val in state_dict.items()
    }
# Zero-initialize a module to reduce its impact at the beginning of training.
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the same module.

    Presumably ``module`` is a torch ``nn.Module`` (it must provide
    ``parameters()``, whose items support ``detach().zero_()``) — confirm
    against callers.
    """
    for param in module.parameters():
        # Zero through a detached view so the in-place write is not recorded
        # by autograd (assuming torch tensors).
        param.detach().zero_()
    return module
def filter_model_checkpoint(ckpt_state_dict, model_state_dict, need_strs=None):
    """
    Select checkpoint entries for model keys that contain a required substring.

    Args:
        ckpt_state_dict: The state dictionary of the checkpoint.
        model_state_dict: The state dictionary of the model (only its keys
            are used).
        need_strs: Optional iterable of substrings. A key is kept only if it
            contains at least one of them, so an empty iterable keeps nothing.
            ``None`` disables the substring filter: every model key present
            in the checkpoint is kept.

    Returns:
        dict: ``{key: ckpt_state_dict[key]}`` for each selected key, in the
        model's key order. Model keys missing from the checkpoint are skipped
        silently.
    """
    if need_strs is None:
        return {key: ckpt_state_dict[key] for key in model_state_dict if key in ckpt_state_dict}
    # Materialize once so a generator argument is not exhausted by the first key.
    need_strs = tuple(need_strs)
    return {
        key: ckpt_state_dict[key]
        for key in model_state_dict
        if key in ckpt_state_dict and any(need_str in key for need_str in need_strs)
    }