|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
import warnings |
|
from typing import Any, List, Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge |
|
|
|
|
|
class HRALayer(BaseTunerLayer):
    """Per-adapter bookkeeping shared by the HRA layer wrappers.

    For every adapter name this class stores the rank (``hra_r``), the Gram-Schmidt flag
    (``hra_apply_GS``) and the trainable matrix ``hra_u`` whose columns are consumed by the
    ``get_delta_weight`` methods of the concrete wrappers. Only ``nn.Linear`` and ``nn.Conv2d``
    base layers are accepted.
    """

    # Names of the attributes that hold trainable adapter weights.
    adapter_layer_names = ("hra_u",)
    # Names of the attributes that hold other per-adapter settings.
    other_param_names = ("hra_r", "hra_apply_GS")

    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        """Record the wrapped layer and derive ``in_features`` / ``out_features`` from it.

        Args:
            base_layer (`nn.Module`): The layer to wrap.
            **kwargs: Stored verbatim on ``self.kwargs``.

        Raises:
            ValueError: If the resolved base layer is neither ``nn.Linear`` nor ``nn.Conv2d``.
        """
        self.base_layer = base_layer
        self.hra_r = {}
        self.hra_apply_GS = {}
        self.hra_u = nn.ParameterDict({})

        # Start in the "not merged, adapters enabled" state.
        self._disable_adapters = False
        self.merged_adapters = []
        self.kwargs = kwargs

        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            self.in_features, self.out_features = base_layer.in_features, base_layer.out_features
        elif isinstance(base_layer, nn.Conv2d):
            self.in_features, self.out_features = base_layer.in_channels, base_layer.out_channels
        else:
            raise ValueError(f"Unsupported layer type {type(base_layer)}")

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        apply_GS: bool,
        init_weights: bool,
        **kwargs,
    ) -> None:
        """Internal function to create hra adapter

        Args:
            adapter_name (`str`): Name for the adapter to add.
            r (`int`): Rank for the added adapter.
            init_weights (`bool`): Whether to initialize weights.
            apply_GS (`bool`): Whether to apply Gram-Schmidt orthogonalization or not.

        Raises:
            ValueError: If ``r`` is not strictly positive.
            TypeError: If the base layer is neither ``nn.Linear`` nor ``nn.Conv2d``.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.hra_r[adapter_name] = r
        self.hra_apply_GS[adapter_name] = apply_GS

        # hra_u has one row per input feature of the flattened base weight and one column per
        # reflection vector.
        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            num_rows = self.in_features
        elif isinstance(base_layer, nn.Conv2d):
            # Use both kernel dimensions so that non-square kernels get the right size;
            # for square kernels this equals the previous kernel_size[0] ** 2.
            kernel_h, kernel_w = base_layer.kernel_size
            num_rows = self.in_features * kernel_h * kernel_w
        else:
            raise TypeError(f"HRA is not implemented for base layers of type {type(base_layer).__name__}")
        self.hra_u[adapter_name] = nn.Parameter(torch.empty(num_rows, r), requires_grad=True)

        # Any truthy `init_weights` selects the symmetric initialization.
        if init_weights:
            self.reset_hra_parameters(adapter_name)
        else:
            self.reset_hra_parameters_random(adapter_name)

        # Move the new parameter next to the base layer and refresh the active adapters.
        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)

    def reset_hra_parameters(self, adapter_name: str):
        """Initialize ``hra_u[adapter_name]`` symmetrically when the rank is even.

        For an even rank, half as many columns are drawn with Kaiming-uniform initialization and
        each one is duplicated into an adjacent pair. When reflections are applied one after
        another (the non-Gram-Schmidt path of ``get_delta_weight``), two identical consecutive
        reflections cancel, so the adapter starts out as the identity. For an odd rank this is
        impossible; a warning is emitted and plain Kaiming-uniform initialization is used.
        """
        current = self.hra_u[adapter_name]
        if self.hra_r[adapter_name] % 2 != 0:
            warnings.warn("The symmetric initialization can NOT be performed when r is odd!")
            nn.init.kaiming_uniform_(current, a=math.sqrt(5))
        else:
            num_rows, num_cols = current.shape
            # Allocate on the parameter's own device/dtype so that re-initializing an adapter
            # that has already been moved or cast does not silently reset its placement.
            half_u = torch.zeros(num_rows, num_cols // 2, dtype=current.dtype, device=current.device)
            nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
            self.hra_u[adapter_name] = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1))

    def reset_hra_parameters_random(self, adapter_name: str):
        """Fill ``hra_u[adapter_name]`` in place with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.hra_u[adapter_name], a=math.sqrt(5))

    def scale_layer(self, scale: float) -> None:
        """Warn that scaling is unsupported; no state is modified.

        Nothing happens for ``scale == 1``. Otherwise one warning is issued for every active
        adapter that has an ``hra_u`` entry.
        """
        if scale == 1:
            return

        for active_adapter in self.active_adapters:
            if active_adapter not in self.hra_u.keys():
                continue

            warnings.warn("Scaling operation for HRA not supported! Automatically set scale to 1.")

    def unscale_layer(self, scale=None) -> None:
        """Warn that unscaling is unsupported; ``scale`` is ignored and no state is modified."""
        for active_adapter in self.active_adapters:
            if active_adapter not in self.hra_u.keys():
                continue

            warnings.warn("Unscaling operation for HRA not supported! Keeping scale at 1.")
|
|
|
|
|
class HRALinear(nn.Module, HRALayer):
    """
    HRA implemented in a dense layer.

    The effective weight is ``W @ D`` where ``W`` is the base ``nn.Linear`` weight and ``D`` is the
    square ``(in_features, in_features)`` matrix returned by `get_delta_weight`, built from
    reflections of the form ``I - 2 u u^T``.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        apply_GS: bool = False,
        init_weights: Union[bool, str] = True,
        **kwargs,
    ) -> None:
        """Wrap ``base_layer`` and create the first adapter.

        Args:
            base_layer: The layer to wrap.
            adapter_name (`str`): Name of the adapter to create.
            r (`int`): Number of reflection vectors; must be positive (checked in `update_layer`).
            apply_GS (`bool`): Whether to orthogonalize the vectors with Gram-Schmidt.
            init_weights (`bool` or `str`): Truthy values select the symmetric initialization.
            **kwargs: Forwarded to `HRALayer.__init__` and `update_layer`.
        """
        super().__init__()
        HRALayer.__init__(self, base_layer, **kwargs)
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, apply_GS, init_weights, **kwargs)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If `True`, the merged weights are checked for non-finite values before they are written back, and
                a `ValueError` is raised if any are found. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
                Defaults to `None`.

        Raises:
            ValueError: If `safe_merge` is set and the merged weight contains NaN or infinite values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # Nothing to merge.
            return

        for active_adapter in adapter_names:
            if active_adapter not in self.hra_u.keys():
                continue

            base_layer = self.get_base_layer()
            orig_weight = base_layer.weight.data
            delta_weight = self.get_delta_weight(active_adapter)
            # torch.mm needs matching dtypes: multiply in the adapter's dtype and cast the
            # product back so the base weight keeps its original dtype.
            merged_weight = torch.mm(orig_weight.to(delta_weight.dtype), delta_weight).to(orig_weight.dtype)

            # The product is a fresh tensor, so the base weight is untouched until the
            # assignment below; a failed check leaves the layer unmodified.
            if safe_merge and not torch.isfinite(merged_weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )

            base_layer.weight.data = merged_weight
            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.

        Adapters are undone in reverse merge order by multiplying with the reversed reflection
        product from `get_delta_weight`.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.hra_u.keys():
                continue

            base_layer = self.get_base_layer()
            merged_weight = base_layer.weight.data
            delta_weight = self.get_delta_weight(active_adapter, reverse=True)
            restored = torch.mm(merged_weight.to(delta_weight.dtype), delta_weight)
            base_layer.weight.data = restored.to(merged_weight.dtype)

    def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
        """Build the square matrix that right-multiplies the base weight.

        Args:
            adapter_name (`str`): Adapter whose ``hra_u`` columns are used.
            reverse (`bool`): Only read when Gram-Schmidt is disabled. If `True`, the reflections are
                multiplied in the opposite order. Each reflection is symmetric and its own inverse, so this yields
                the inverse of the forward product.

        Returns:
            `torch.Tensor`: A matrix of shape ``(hra_u.shape[0], hra_u.shape[0])`` on the device and in the dtype of
            ``hra_u``.
        """
        rank = self.hra_r[adapter_name]
        apply_GS = self.hra_apply_GS[adapter_name]
        opt_u = self.hra_u[adapter_name]
        dim = opt_u.shape[0]
        identity = torch.eye(dim, device=opt_u.device, dtype=opt_u.dtype)

        if apply_GS:
            # Orthonormalize the columns one by one, then apply all reflections at once:
            # with orthonormal columns Q, I - 2 Q Q^T is its own inverse, which is why
            # `reverse` does not matter on this path.
            basis = [(opt_u[:, 0] / opt_u[:, 0].norm()).view(-1, 1)]
            for i in range(1, rank):
                ui = opt_u[:, i].view(-1, 1)
                for j in range(i):
                    ui = ui - (basis[j].t() @ ui) * basis[j]
                basis.append((ui / ui.norm()).view(-1, 1))
            q = torch.cat(basis, dim=1)
            return identity - 2 * q @ q.t()

        # Normalize every column, then chain the reflections: weight <- weight @ (I - 2 u u^T).
        opt_u = opt_u / opt_u.norm(dim=0)
        indices = range(rank - 1, -1, -1) if reverse else range(rank)

        weight = identity
        for i in indices:
            ui = opt_u[:, i].view(-1, 1)
            weight = weight - 2 * weight @ ui @ ui.t()

        return weight

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Apply the layer, with the active adapters folded in unless disabled or already merged.

        With adapters disabled (after unmerging if necessary) or already merged, the call is
        delegated to the base layer. Otherwise the base weight is right-multiplied by the delta
        matrices of all active adapters and applied with ``F.linear``; ``*args`` / ``**kwargs`` are
        not used on that path. The result is cast back to the dtype ``x`` had on entry.
        """
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            base_layer = self.get_base_layer()
            orig_weight = base_layer.weight.data
            new_weight = torch.eye(self.in_features, device=x.device, dtype=orig_weight.dtype)

            for active_adapter in self.active_adapters:
                if active_adapter not in self.hra_u.keys():
                    continue
                delta_weight = self.get_delta_weight(active_adapter)
                # Adapters may be stored in a different dtype than the base weight (or than
                # each other); torch.mm requires both operands to agree.
                new_weight = torch.mm(new_weight.to(delta_weight.dtype), delta_weight)

            # The final product, the input and the bias all use the base weight's dtype.
            x = x.to(orig_weight.dtype)
            new_weight = torch.mm(orig_weight, new_weight.to(orig_weight.dtype))

            result = F.linear(input=x, weight=new_weight, bias=base_layer.bias)

        result = result.to(previous_dtype)
        return result

    def __repr__(self) -> str:
        """Return the default module representation prefixed with ``hra.``."""
        rep = super().__repr__()
        return "hra." + rep
|
|
|
|
|
class HRAConv2d(nn.Module, HRALayer):
    """HRA implemented in Conv2d layer

    The 4-D convolution weight is flattened to ``(out_channels, -1)``, right-multiplied by the
    square matrix returned by `get_delta_weight` (built from reflections of the form
    ``I - 2 u u^T``), and reshaped back to its original shape.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        apply_GS: bool = False,
        init_weights: Union[bool, str] = True,
        **kwargs,
    ):
        """Wrap ``base_layer`` and create the first adapter.

        Args:
            base_layer: The layer to wrap.
            adapter_name (`str`): Name of the adapter to create.
            r (`int`): Number of reflection vectors; must be positive (checked in `update_layer`).
            apply_GS (`bool`): Whether to orthogonalize the vectors with Gram-Schmidt.
            init_weights (`bool` or `str`): Truthy values select the symmetric initialization.
            **kwargs: Forwarded to `update_layer` only.
        """
        super().__init__()
        # NOTE(review): kwargs are not forwarded to HRALayer.__init__ here, so self.kwargs
        # stays empty for this wrapper; kept as-is to avoid changing stored state.
        HRALayer.__init__(self, base_layer)
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, apply_GS, init_weights, **kwargs)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If `True`, the merged weights are checked for non-finite values before they are written back, and
                a `ValueError` is raised if any are found. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
                Defaults to `None`.

        Raises:
            ValueError: If `safe_merge` is set and the merged weight contains NaN or infinite values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # Nothing to merge.
            return

        for active_adapter in adapter_names:
            if active_adapter not in self.hra_u.keys():
                continue

            base_layer = self.get_base_layer()
            delta_weight = self.get_delta_weight(active_adapter)
            merged_weight = self._apply_delta(base_layer.weight.data, delta_weight)

            # The product is a fresh tensor, so the base weight is untouched until the
            # assignment below; a failed check leaves the layer unmodified.
            if safe_merge and not torch.isfinite(merged_weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )

            base_layer.weight.data = merged_weight
            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.

        Adapters are undone in reverse merge order by multiplying with the reversed reflection
        product from `get_delta_weight`.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.hra_u.keys():
                continue

            base_layer = self.get_base_layer()
            delta_weight = self.get_delta_weight(active_adapter, reverse=True)
            base_layer.weight.data = self._apply_delta(base_layer.weight.data, delta_weight)

    def _apply_delta(self, weight: torch.Tensor, delta_weight: torch.Tensor) -> torch.Tensor:
        """Return ``weight`` flattened to 2-D, multiplied by ``delta_weight``, in its original shape and dtype."""
        # reshape (not view) also handles non-contiguous weights, and -1 avoids assuming a
        # square kernel. The multiplication runs in the adapter's dtype because torch.mm
        # requires both operands to agree.
        flat_weight = weight.reshape(self.out_features, -1).to(delta_weight.dtype)
        return torch.mm(flat_weight, delta_weight).reshape(weight.shape).to(weight.dtype)

    def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
        """Build the square matrix that right-multiplies the flattened base weight.

        Args:
            adapter_name (`str`): Adapter whose ``hra_u`` columns are used.
            reverse (`bool`): Only read when Gram-Schmidt is disabled. If `True`, the reflections are
                multiplied in the opposite order. Each reflection is symmetric and its own inverse, so this yields
                the inverse of the forward product.

        Returns:
            `torch.Tensor`: A matrix of shape ``(hra_u.shape[0], hra_u.shape[0])`` on the device and in the dtype of
            ``hra_u``.
        """
        rank = self.hra_r[adapter_name]
        apply_GS = self.hra_apply_GS[adapter_name]
        opt_u = self.hra_u[adapter_name]
        dim = opt_u.shape[0]
        identity = torch.eye(dim, device=opt_u.device, dtype=opt_u.dtype)

        if apply_GS:
            # Orthonormalize the columns one by one, then apply all reflections at once:
            # with orthonormal columns Q, I - 2 Q Q^T is its own inverse, which is why
            # `reverse` does not matter on this path.
            basis = [(opt_u[:, 0] / opt_u[:, 0].norm()).view(-1, 1)]
            for i in range(1, rank):
                ui = opt_u[:, i].view(-1, 1)
                for j in range(i):
                    ui = ui - (basis[j].t() @ ui) * basis[j]
                basis.append((ui / ui.norm()).view(-1, 1))
            q = torch.cat(basis, dim=1)
            return identity - 2 * q @ q.t()

        # Normalize every column, then chain the reflections: weight <- weight @ (I - 2 u u^T).
        opt_u = opt_u / opt_u.norm(dim=0)
        indices = range(rank - 1, -1, -1) if reverse else range(rank)

        weight = identity
        for i in indices:
            ui = opt_u[:, i].view(-1, 1)
            weight = weight - 2 * weight @ ui @ ui.t()

        return weight

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Apply the convolution, with the active adapters folded in unless disabled or already merged.

        With adapters disabled (after unmerging if necessary) or already merged, the call is
        delegated to the base layer. Otherwise the flattened base weight is right-multiplied by
        the delta matrices of all active adapters and applied with ``F.conv2d``; ``*args`` /
        ``**kwargs`` are not used on that path. The result is cast back to the dtype ``x`` had on
        entry.
        """
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            base_layer = self.get_base_layer()
            orig_weight = base_layer.weight.data
            # reshape (not view) also handles non-contiguous weights; -1 avoids assuming a
            # square kernel.
            flat_weight = orig_weight.reshape(self.out_features, -1)

            new_weight = torch.eye(flat_weight.shape[1], device=x.device, dtype=orig_weight.dtype)
            for active_adapter in self.active_adapters:
                if active_adapter not in self.hra_u.keys():
                    continue
                delta_weight = self.get_delta_weight(active_adapter)
                # Adapters may be stored in a different dtype than the base weight (or than
                # each other); torch.mm requires both operands to agree.
                new_weight = torch.mm(new_weight.to(delta_weight.dtype), delta_weight)

            # The final product, the input and the bias all use the base weight's dtype.
            x = x.to(orig_weight.dtype)
            new_weight = torch.mm(flat_weight, new_weight.to(flat_weight.dtype)).reshape(orig_weight.shape)

            # Forward the base layer's full convolution settings so this path matches what the
            # base layer computes after a merge (asymmetric stride/padding, dilation, groups).
            # NOTE: a non-"zeros" padding_mode on the base layer is not replicated here.
            result = F.conv2d(
                input=x,
                weight=new_weight,
                bias=base_layer.bias,
                stride=base_layer.stride,
                padding=base_layer.padding,
                dilation=base_layer.dilation,
                groups=base_layer.groups,
            )

        result = result.to(previous_dtype)
        return result

    def __repr__(self) -> str:
        """Return the default module representation prefixed with ``hra.``."""
        rep = super().__repr__()
        return "hra." + rep
|
|