File size: 7,632 Bytes
14ce5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
"""This file contains the definition of base classes.
We thank the following public implementations for inspiring this code:
https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py
"""
import copy
import os
from typing import Union, Callable, Tuple, Dict, Optional, List
import torch
def _first_module_tensor(module: torch.nn.Module) -> torch.Tensor:
    """Return a representative tensor held by ``module``.

    Lookup order:
      1. the first tensor yielded by ``module.parameters()``;
      2. the first plain tensor attribute (a tensor stored directly in a
         module's ``__dict__`` rather than registered as a parameter or
         buffer), scanning ``module.modules()`` in order;
      3. the first tensor yielded by ``module.buffers()``.

    Raises:
        ValueError: if the module holds no tensors at all.
    """
    first_parameter = next(module.parameters(), None)
    if first_parameter is not None:
        return first_parameter
    # For torch.nn.DataParallel compatibility in PyTorch 1.5: replicas may hold
    # their tensors as plain attributes instead of registered parameters.
    # Walking modules() matches the order the previous `_named_members`-based
    # lookup used, without relying on that private PyTorch API.
    for submodule in module.modules():
        for value in vars(submodule).values():
            if torch.is_tensor(value):
                return value
    first_buffer = next(module.buffers(), None)
    if first_buffer is not None:
        return first_buffer
    raise ValueError(
        f"Cannot infer device/dtype: {type(module).__name__} has no parameters, "
        "tensor attributes or buffers."
    )


def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
    """Return the device of the first tensor held by the module ``parameter``.

    Parameters are checked first, then plain tensor attributes (of the module
    and its submodules), then buffers.

    Raises:
        ValueError: if the module holds no tensors at all.
    """
    return _first_module_tensor(parameter).device


def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
    """Return the dtype of the first tensor held by the module ``parameter``.

    Parameters are checked first, then plain tensor attributes (of the module
    and its submodules), then buffers.

    Raises:
        ValueError: if the module holds no tensors at all.
    """
    return _first_module_tensor(parameter).dtype
class BaseModel(torch.nn.Module):
    """``torch.nn.Module`` base class with weight save/load helpers and
    parameter introspection (device, dtype, parameter counts)."""

    def __init__(self):
        super().__init__()

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        save_function: Optional[Callable] = None,
        state_dict: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Save the model's weights to ``<save_directory>/pytorch_model.bin``.

        The saved state dict can be read back (e.g. with ``torch.load``) and
        passed to ``load_pretrained``.

        Args:
            save_directory: Directory to save into; created (including parents)
                if it does not exist. If it names an existing *file*, a message
                is printed and nothing is saved.
            save_function: Called as ``save_function(state_dict, path)``.
                Defaults to ``torch.save``. Useful on distributed setups such as
                TPUs where ``torch.save`` must be replaced by another method.
            state_dict: State dict to save. Defaults to ``self.state_dict()``.
        """
        if os.path.isfile(save_directory):
            # Non-fatal: report the problem and return without saving.
            print(f"Provided path ({save_directory}) should be a directory, not a file")
            return
        if save_function is None:
            save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)
        if state_dict is None:
            state_dict = self.state_dict()

        weights_path = os.path.join(save_directory, "pytorch_model.bin")
        save_function(state_dict, weights_path)
        print(f"Model weights saved in {weights_path}")

    @staticmethod
    def _rename_checkpoint_keys(
        checkpoint: Dict[str, torch.Tensor], rename_keys: Dict[str, str]
    ) -> Dict[str, torch.Tensor]:
        """Return a shallow copy of ``checkpoint`` with key prefixes rewritten.

        Each key is matched against the prefixes of ``rename_keys`` in dict
        order; the first match has only its *leading* prefix replaced. Renames
        are computed from the original keys, so chained or identity mappings
        behave predictably. If a renamed key collides with an existing key, the
        renamed entry wins.
        """
        # Only keys change, so a shallow copy suffices: tensors are shared rather
        # than duplicated, and the mapping's type and instance attributes (such
        # as the `_metadata` PyTorch attaches to state dicts) are preserved.
        renamed = copy.copy(checkpoint)
        moves = []
        for key in checkpoint:
            for old_prefix, new_prefix in rename_keys.items():
                if key.startswith(old_prefix):
                    moves.append((key, new_prefix + key[len(old_prefix):]))
                    break
        # Drop every old key before writing new ones, so a rename can never
        # delete an entry that another (or the same) rename just produced.
        for old_key, _ in moves:
            del renamed[old_key]
        for old_key, new_key in moves:
            renamed[new_key] = checkpoint[old_key]
        return renamed

    def load_pretrained(
        self,
        checkpoint: Dict[str, torch.Tensor],
        strict_loading: bool = True,
        torch_dtype: Optional[torch.dtype] = None,
        rename_keys: Optional[Dict[str, str]] = None,
    ):
        """Load weights from an in-memory state dict and switch to eval mode.

        The model is set in evaluation mode with ``model.eval()`` (dropout
        modules are deactivated). Call ``model.train()`` to train it again.

        Args:
            checkpoint: Mapping of parameter/buffer names to tensors, e.g. the
                result of ``torch.load(path, map_location="cpu")``. The caller's
                mapping is not modified.
            strict_loading: Passed as ``strict`` to ``load_state_dict``: whether
                the checkpoint keys must exactly match this model's keys.
            torch_dtype: If given, the model is cast to this dtype after
                loading. Defaults to ``None`` (no conversion).
            rename_keys: Maps old key prefixes to new prefixes; only the leading
                prefix of each matching key is replaced (first match wins).
                Defaults to ``None`` (no renaming).

        Raises:
            ValueError: if ``torch_dtype`` is given but is not a ``torch.dtype``.
        """
        # Validate arguments before touching the weights so misuse fails fast.
        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
            )

        if rename_keys is not None:
            checkpoint = self._rename_checkpoint_keys(checkpoint, rename_keys)

        self.load_state_dict(checkpoint, strict=strict_loading)
        if torch_dtype is not None:
            self.to(torch_dtype)
        # Set model in evaluation mode to deactivate dropout modules by default.
        self.eval()

    @property
    def device(self):
        """Returns the device of the model.

        Returns:
            `torch.device`: The device reported by `get_parameter_device(self)`.
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """Returns the dtype of the model, as reported by `get_parameter_dtype(self)`."""
        return get_parameter_dtype(self)

    def num_parameters(
        self, only_trainable: bool = False, exclude_embeddings: bool = False
    ) -> int:
        """
        Get number of (optionally, trainable or non-embeddings) parameters in the module.

        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of trainable parameters
            exclude_embeddings (`bool`, *optional*, defaults to `False`):
                Whether or not to exclude the `weight` of every `torch.nn.Embedding`
                submodule from the count

        Returns:
            `int`: The number of parameters.
        """
        excluded_names = set()
        if exclude_embeddings:
            # A set gives O(1) membership tests below; the root module has the
            # empty name, so its weight is plain "weight", not ".weight".
            excluded_names = {
                f"{name}.weight" if name else "weight"
                for name, module in self.named_modules()
                if isinstance(module, torch.nn.Embedding)
            }
        return sum(
            parameter.numel()
            for name, parameter in self.named_parameters()
            if name not in excluded_names
            and (parameter.requires_grad or not only_trainable)
        )
|