# Chain-of-Zoom / lora / lora_utils.py
import torch
from torch import nn
from lora.lora_layers import LoraInjectedLinear, LoraInjectedConv2d


def _find_modules(
    model,
    ancestor_class=None,
    search_class=[nn.Linear],
    exclude_children_of=[LoraInjectedLinear],
):
    """Yield (parent, attribute_name, module) for every module whose type is in
    `search_class` and that lives under an ancestor whose class name appears in
    `ancestor_class` (or under any module when `ancestor_class` is None)."""
    # Get the targets we should replace all linears under.
    if ancestor_class is not None:
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # Otherwise, naively iterate over all modules.
        ancestors = [module for module in model.modules()]

    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            # if 'norm1_context' in fullname:
            if any(isinstance(module, _class) for _class in search_class):
                # Resolve the direct parent of the matched module from its full name.
                *path, name = fullname.split(".")
                parent = ancestor
                while path:
                    parent = parent.get_submodule(path.pop(0))
                # Skip modules that are already children of an injected LoRA layer.
                if exclude_children_of and any(
                    isinstance(parent, _class) for _class in exclude_children_of
                ):
                    continue
                yield parent, name, module
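
# Usage sketch (not part of the original file; `transformer` is a placeholder
# for whatever module you pass in): list every nn.Linear that sits under an
# AdaLayerNormZero block.
#
#   for parent, name, module in _find_modules(
#       transformer, {"AdaLayerNormZero"}, search_class=[nn.Linear]
#   ):
#       print(f"{parent.__class__.__name__}.{name}: {module}")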


def extract_lora_ups_down(model, target_replace_module={'AdaLayerNormZero'}):  # Attention for kv_lora
    """Collect (lora_up, lora_down) pairs from every injected LoRA layer in `model`."""
    loras = []
    for _m, _n, _child_module in _find_modules(
        model,
        target_replace_module,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):
        loras.append((_child_module.lora_up, _child_module.lora_down))
    if len(loras) == 0:
        raise ValueError("No lora injected.")
    return loras
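
# Usage sketch (assumes `model` already contains LoraInjectedLinear /
# LoraInjectedConv2d modules inside its AdaLayerNormZero blocks; `model` is a
# placeholder name):
#
#   pairs = extract_lora_ups_down(model, target_replace_module={"AdaLayerNormZero"})
#   for up, down in pairs:
#       print(up.weight.shape, down.weight.shape)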


def save_lora_weight(
    model,
    path="./lora.pt",
    target_replace_module={'AdaLayerNormZero'},  # Attention for kv_lora
    save_half: bool = False,
):
    """Save the LoRA up/down weights of `model` to `path` as a flat list of CPU tensors."""
    dtype = torch.float16 if save_half else torch.float32
    weights = []
    for _up, _down in extract_lora_ups_down(
        model, target_replace_module=target_replace_module
    ):
        weights.append(_up.weight.to("cpu").to(dtype))
        weights.append(_down.weight.to("cpu").to(dtype))

    torch.save(weights, path)
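
# Usage sketch (the model and path are placeholders): save the LoRA weights in
# half precision. The checkpoint is a flat Python list of tensors, alternating
# up/down in the order `_find_modules` visits them, so it can be read back with
# `torch.load(path)` and zipped into pairs.
#
#   save_lora_weight(model, path="./lora.pt",
#                    target_replace_module={"AdaLayerNormZero"}, save_half=True)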