|
''' |
|
This is to save and load the model. |
|
''' |
|
|
|
def check_model_checkpoint_consistency(ckpt_state_dict, model_state_dict, special_strs=None): |
|
""" |
|
Maintain all checkpoint keys. Ignore keys with specific endings if absent. |
|
Raise exception for model keys not in checkpoint unless ignored. |
|
ckpt: The state dictionary of the checkpoint. |
|
model_state_dict: The state dictionary of the model. |
|
special_endings: A list of specific endings of strings to be ignored. |
|
""" |
|
filtered_ckpt = {} |
|
special_modules =[] |
|
for key in model_state_dict.keys(): |
|
if key in ckpt_state_dict: |
|
filtered_ckpt[key] = ckpt_state_dict[key] |
|
elif any(special_str in key for special_str in special_strs): |
|
special_modules.append(key) |
|
continue |
|
else: |
|
raise KeyError(f"Key '{key}' not found in checkpoint and does not match any special endings.") |
|
|
|
def remove_module_prefix(state_dict): |
|
new_state_dict = {} |
|
for key, value in state_dict.items(): |
|
if key.startswith('module.'): |
|
new_key = key[len('module.'):] |
|
new_state_dict[new_key] = value |
|
else: |
|
new_state_dict[key] = value |
|
return new_state_dict |
|
|
|
|
|
|
|
|
|
def zero_module(module): |
|
""" |
|
Zero out the parameters of a module and return it. |
|
""" |
|
for p in module.parameters(): |
|
p.detach().zero_() |
|
return module |
|
|
|
|
|
def filter_model_checkpoint(ckpt_state_dict, model_state_dict, need_strs=None): |
|
filtered_ckpt = {} |
|
for key in model_state_dict.keys(): |
|
if key in ckpt_state_dict and any(need_str in key for need_str in need_strs): |
|
filtered_ckpt[key] = ckpt_state_dict[key] |
|
else: |
|
continue |
|
|
|
return filtered_ckpt |
|
|