import torch
from torch import nn

from lora.lora_layers import LoraInjectedLinear, LoraInjectedConv2d


def _find_modules(
    model,
    ancestor_class=None,
    search_class=[nn.Linear],
    exclude_children_of=[LoraInjectedLinear],
):
    """Yield (parent, name, module) for every module whose type is in
    `search_class` and that lives under an ancestor whose class name is in
    `ancestor_class`, skipping modules whose direct parent is listed in
    `exclude_children_of`."""
    # Get the ancestors under which we should look for target linears.
    if ancestor_class is not None:
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # Fall back to naively iterating over all modules.
        ancestors = [module for module in model.modules()]

    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            # if 'norm1_context' in fullname:
            if any([isinstance(module, _class) for _class in search_class]):
                # Resolve the direct parent of the matched module.
                *path, name = fullname.split(".")
                parent = ancestor
                while path:
                    parent = parent.get_submodule(path.pop(0))
                if exclude_children_of and any(
                    [isinstance(parent, _class) for _class in exclude_children_of]
                ):
                    continue
                yield parent, name, module


def extract_lora_ups_down(model, target_replace_module={'AdaLayerNormZero'}):  # Attention for kv_lora
    loras = []
    for _m, _n, _child_module in _find_modules(
        model,
        target_replace_module,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):
        loras.append((_child_module.lora_up, _child_module.lora_down))
    if len(loras) == 0:
        raise ValueError("No lora injected.")
    return loras


def save_lora_weight(
    model,
    path="./lora.pt",
    target_replace_module={'AdaLayerNormZero'},  # Attention for kv_lora
    save_half: bool = False,
):
    weights = []
    for _up, _down in extract_lora_ups_down(
        model, target_replace_module=target_replace_module
    ):
        dtype = torch.float16 if save_half else torch.float32
        weights.append(_up.weight.to("cpu").to(dtype))
        weights.append(_down.weight.to("cpu").to(dtype))

    torch.save(weights, path)
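

# --- Usage sketch (not part of the original file; names below are assumptions) ---
# A minimal sanity check for _find_modules on a toy model, plus a commented
# example of how save_lora_weight would typically be called. `ToyBlock`,
# `ToyModel`, and `my_model` are hypothetical; saving real LoRA weights
# requires a model that already contains LoraInjectedLinear /
# LoraInjectedConv2d modules under the target ancestors (otherwise
# extract_lora_ups_down raises "No lora injected.").
if __name__ == "__main__":

    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(8, 8)
            self.fc2 = nn.Linear(8, 4)

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.block = ToyBlock()
            self.head = nn.Linear(4, 2)

    toy = ToyModel()
    # fc1 and fc2 are yielded (their ancestor's class name is in the set);
    # head is not, because ToyModel is not a listed ancestor class.
    for parent, name, module in _find_modules(toy, ancestor_class={"ToyBlock"}):
        print(f"found {module.__class__.__name__} '{name}' under {parent.__class__.__name__}")

    # Hypothetical call on a model that already has LoRA layers injected:
    # save_lora_weight(my_model, path="./lora.pt",
    #                  target_replace_module={'AdaLayerNormZero'}, save_half=True)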