|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import warnings |
|
from typing import Any, List, Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from transformers.pytorch_utils import Conv1D |
|
|
|
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge |
|
from peft.utils import transpose |
|
|
|
|
|
class IA3Layer(BaseTunerLayer):
    """
    Mixin holding the (IA)^3 state shared by the Linear and convolutional wrappers.

    Every adapter owns one learned vector stored in ``ia3_l``. When ``is_feedforward`` is True the
    vector is sized by the wrapped layer's input features, otherwise by its output features.
    """

    # Attribute names holding the per-adapter trainable parameters.
    adapter_layer_names = ("ia3_l",)

    def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None:
        """
        Args:
            base_layer: The layer being adapted (possibly itself a wrapper; it is unwrapped through
                ``get_base_layer()`` before its feature sizes are read).
            is_feedforward: Selects whether adapter vectors scale input (True) or output (False)
                features.

        Raises:
            ValueError: If the unwrapped base layer is not a supported layer type.
        """
        # The wrapped layer is registered before ``ia3_l`` (keeps submodule order), and
        # ``get_base_layer()`` below needs it to be set.
        self.base_layer = base_layer
        self.ia3_l = nn.ParameterDict({})
        # Adapters start enabled and unmerged.
        self._disable_adapters = False
        self.merged_adapters = []
        self.is_feedforward = is_feedforward

        layer = self.get_base_layer()
        if isinstance(layer, nn.Linear):
            features = (layer.in_features, layer.out_features)
        elif isinstance(layer, (nn.Conv2d, nn.Conv3d)):
            features = (layer.in_channels, layer.out_channels)
        elif isinstance(layer, nn.Embedding):
            features = (layer.num_embeddings, layer.embedding_dim)
        elif isinstance(layer, Conv1D):
            # The weight shape is read as (in_features, out_features). ``ds_shape``, when present,
            # is preferred over ``shape`` (presumably the full shape of a partitioned weight,
            # e.g. under DeepSpeed -- confirm).
            kernel = layer.weight
            features = kernel.ds_shape if hasattr(kernel, "ds_shape") else kernel.shape
        else:
            raise ValueError(f"Unsupported layer type {type(layer)}")

        self.in_features, self.out_features = features

    def update_layer(self, adapter_name, init_ia3_weights):
        """
        Create (or replace) the scaling vector for ``adapter_name``.

        The vector has shape ``(1, in_features)`` for feed-forward layers and ``(out_features, 1)``
        otherwise. It is randomly initialised, then reset to ones when ``init_ia3_weights`` is
        truthy.
        """
        shape = (1, self.in_features) if self.is_feedforward else (self.out_features, 1)
        self.ia3_l[adapter_name] = nn.Parameter(torch.randn(shape))
        if init_ia3_weights:
            self.reset_ia3_parameters(adapter_name)
        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)

    def reset_ia3_parameters(self, adapter_name):
        """Fill the vector of ``adapter_name`` with ones; no-op for unknown adapters."""
        if adapter_name not in self.ia3_l:
            return
        # All-ones means the adapter starts out as an identity scaling.
        nn.init.constant_(self.ia3_l[adapter_name], 1.0)
|
|
|
|
|
class Linear(nn.Module, IA3Layer):
    """
    (IA)^3 wrapper for ``nn.Linear`` (and, with ``fan_in_fan_out=True``, transformers' ``Conv1D``).

    Depending on ``is_feedforward``, the learned vector rescales either the layer input (before the
    base layer runs) or the layer output (after it runs).
    """

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str,
        fan_in_fan_out: bool = False,
        is_feedforward: bool = False,
        is_target_conv_1d_layer: bool = False,
        init_ia3_weights: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()
        IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
        # When True the base weight is stored transposed, so the adapter vector is transposed too
        # before being multiplied into it.
        self.fan_in_fan_out = fan_in_fan_out
        self.is_target_conv_1d_layer = is_target_conv_1d_layer
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, init_ia3_weights)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Fold adapter scaling vectors into the base layer's weight (and bias, for non-feed-forward
        layers).

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merged weight is checked for non-finite values before being written back
                to the base layer. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                Adapters to merge. If None, all active adapters are merged. Defaults to `None`.

        Raises:
            ValueError: If ``safe_merge`` is set and the merged weight contains NaN/inf values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # Nothing left to merge.
            return

        for name in adapter_names:
            if name not in self.ia3_l.keys():
                continue

            base_layer = self.get_base_layer()
            vector = transpose(self.ia3_l[name].data, self.fan_in_fan_out)
            weight_dtype = base_layer.weight.data.dtype

            scaled_weight = torch.mul(base_layer.weight.data, vector)
            if safe_merge and not torch.isfinite(scaled_weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {name} seems to be broken"
                )
            base_layer.weight.data = scaled_weight.to(weight_dtype)

            # Output-side scaling also applies to the bias term.
            if not self.is_feedforward and base_layer.bias is not None:
                bias_vector = self.ia3_l[name].reshape(base_layer.bias.shape)
                bias_dtype = base_layer.bias.data.dtype
                base_layer.bias.data = torch.mul(base_layer.bias.data, bias_vector.data).to(bias_dtype)

            self.merged_adapters.append(name)

    def unmerge(self) -> None:
        """
        Undo every merge by dividing the base weight (and bias) by the merged vectors, in reverse
        merge order. A small epsilon keeps the division finite, so the result is approximate.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        warnings.warn("Unmerge result can be inaccurate for (IA)^3.")
        while self.merged_adapters:
            name = self.merged_adapters.pop()
            if name not in self.ia3_l.keys():
                continue

            base_layer = self.get_base_layer()
            # 1e-8 avoids dividing by an exactly-zero scale.
            vector = transpose(self.ia3_l[name].data, self.fan_in_fan_out) + 1e-8
            weight_dtype = base_layer.weight.data.dtype
            base_layer.weight.data = torch.div(base_layer.weight.data, vector).to(weight_dtype)

            if not self.is_feedforward and base_layer.bias is not None:
                bias_vector = self.ia3_l[name].reshape(base_layer.bias.shape)
                bias_dtype = base_layer.bias.data.dtype
                base_layer.bias.data = torch.div(base_layer.bias.data, bias_vector.data + 1e-8).to(bias_dtype)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Run the base layer with the active adapters' scaling applied to its input (feed-forward) or
        output. Disabling adapters unmerges any merged weights first.
        """
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            return self.base_layer(x, *args, **kwargs)
        if self.merged:
            # The scaling already lives inside the base weights.
            return self.base_layer(x, *args, **kwargs)

        input_dtype = x.dtype
        scale_dtype = input_dtype
        ia3_scaling = 1
        for name in self.active_adapters:
            if name not in self.ia3_l.keys():
                continue
            vector = self.ia3_l[name]
            scale_dtype = vector.dtype
            # Combine all active adapters into one flat per-feature scale.
            ia3_scaling *= vector.flatten()

        if not self.is_feedforward:
            out = self.base_layer(x, *args, **kwargs)
            return (out * ia3_scaling).to(out.dtype)

        # NOTE(review): the base weight dtype can differ from the adapter dtype (e.g. bf16 vs fp32);
        # the scaled input is cast back to the caller's input dtype before the base layer runs.
        scaled_input = (x.to(scale_dtype) * ia3_scaling).to(input_dtype)
        return self.base_layer(scaled_input, *args, **kwargs)
|
|
|
|
|
class _ConvNd(nn.Module, IA3Layer):
    """
    Shared (IA)^3 wrapper for ``nn.Conv2d`` / ``nn.Conv3d`` base layers.

    The learned vector has shape ``(1, C, 1, ..., 1)``, with one trailing singleton per remaining
    kernel dimension. ``C`` is the input channel count when ``is_feedforward`` is True and the output
    channel count otherwise, so the vector broadcasts over the channel axis of the activations
    (assumed ``(N, C, *spatial)`` layout).
    """

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str,
        fan_in_fan_out: bool = False,
        is_feedforward: bool = False,
        init_ia3_weights: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()
        IA3Layer.__init__(self, base_layer, is_feedforward=is_feedforward)
        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name
        # Rank of the conv kernel weight; determines how many singleton dims the vector gets.
        self._kernel_dim = base_layer.weight.dim()

        self.update_layer(adapter_name, init_ia3_weights)

    def update_layer(self, adapter_name, init_ia3_weights):
        """
        Create (or replace) the scaling vector for ``adapter_name``.

        The vector is randomly initialised, then reset to ones when ``init_ia3_weights`` is truthy.
        """
        num_features = self.in_features if self.is_feedforward else self.out_features
        weights_size = (1, num_features) + (1,) * (self._kernel_dim - 2)
        weight = torch.randn(weights_size)
        self.ia3_l[adapter_name] = nn.Parameter(weight)
        if init_ia3_weights:
            self.reset_ia3_parameters(adapter_name)
        self._move_adapter_to_device_of_base_layer(adapter_name)
        self.set_adapter(self.active_adapters)

    def _get_weight_scaling(self, adapter_name: str) -> torch.Tensor:
        """Return the adapter vector laid out to broadcast against the conv kernel weight."""
        ia3_scaling = self.ia3_l[adapter_name].data
        if not self.is_feedforward:
            # (1, out, 1, ...) -> (out, 1, 1, ...): scale the kernel's output-channel axis (dim 0)
            # instead of its input-channel axis (dim 1).
            ia3_scaling = ia3_scaling.transpose(0, 1)
        return ia3_scaling

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merged weights are checked for NaNs/infs before being written back to the
                base layer. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.

        Raises:
            ValueError: If ``safe_merge`` is set and the merged weights contain non-finite values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        for active_adapter in adapter_names:
            if active_adapter not in self.ia3_l.keys():
                continue

            base_layer = self.get_base_layer()
            ia3_scaling = self._get_weight_scaling(active_adapter)
            # torch.mul promotes to the wider dtype (e.g. fp16 kernel * fp32 vector -> fp32);
            # cast back so merging does not silently change the base layer's dtype.
            orig_dtype = base_layer.weight.data.dtype
            # torch.mul returns a new tensor, so the base weight is untouched until assignment.
            output_weight = torch.mul(base_layer.weight.data, ia3_scaling)
            if safe_merge and not torch.isfinite(output_weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )
            base_layer.weight.data = output_weight.to(orig_dtype)

            # Output-side scaling also applies to the bias term.
            if not self.is_feedforward and (base_layer.bias is not None):
                scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
                orig_dtype = base_layer.bias.data.dtype
                base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data).to(orig_dtype)

            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.

        Merged vectors are divided back out in reverse merge order. A 1e-8 epsilon guards against
        division by zero, so the result is approximate.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        warnings.warn("Unmerge result can be inaccurate for (IA)^3.")
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.ia3_l.keys():
                continue

            base_layer = self.get_base_layer()
            ia3_scaling = self._get_weight_scaling(active_adapter)
            orig_dtype = base_layer.weight.data.dtype
            base_layer.weight.data = torch.div(base_layer.weight.data, ia3_scaling + 1e-8).to(orig_dtype)

            if not self.is_feedforward and (base_layer.bias is not None):
                scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape)
                orig_dtype = base_layer.bias.data.dtype
                # Divide (not multiply) to invert the scaling applied in merge().
                base_layer.bias.data = torch.div(base_layer.bias.data, scaling.data + 1e-8).to(orig_dtype)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Run the base conv with the active adapters' scaling applied to its input (feed-forward) or
        output. The result is returned in the dtype of ``x``.
        """
        dtype = previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            # The scaling already lives inside the base weights.
            result = self.base_layer(x, *args, **kwargs)
        else:
            ia3_scaling = 1
            for active_adapter in self.active_adapters:
                if active_adapter not in self.ia3_l.keys():
                    continue
                dtype = self.ia3_l[active_adapter].dtype
                ia3_scaling *= self.ia3_l[active_adapter]

            if self.is_feedforward:
                x = x.to(dtype)
                # NOTE(review): weight.dtype can differ from the adapter dtype (e.g. bf16 vs fp32);
                # the scaled input is cast to the base weight dtype before the conv runs.
                interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype)
                result = self.base_layer(interm, *args, **kwargs)
            else:
                result = self.base_layer(x, *args, **kwargs)
                result = result.to(dtype) * ia3_scaling

        result = result.to(previous_dtype)
        return result
|
|
|
|
|
class Conv2d(_ConvNd):
    """(IA)^3 wrapper for ``nn.Conv2d``; rejects base layers whose kernel weight is not 4-D."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        kernel_rank = self._kernel_dim
        if kernel_rank != 4:
            raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
|
|
|
|
|
class Conv3d(_ConvNd):
    """(IA)^3 wrapper for ``nn.Conv3d``; rejects base layers whose kernel weight is not 5-D."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self._kernel_dim != 5:
            # Message previously said "Conv2d", which misidentified the failing layer type.
            raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
|
|