Yw22's picture
init demo
d711508
raw
history blame
15.6 kB
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Set, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft.tuners.lycoris_utils import LycorisLayer
class LoHaLayer(nn.Module, LycorisLayer):
    """Common base for LoHa (low-rank Hadamard product) adapter layers.

    Every adapter owns two factor pairs ``(hada_w1_a, hada_w1_b)`` and ``(hada_w2_a, hada_w2_b)``; the delta weight
    is the element-wise product of the two rebuilt factors, multiplied by ``alpha / r``. For Conv2d layers using the
    effective decomposition each pair additionally owns a core tensor (``hada_t1`` / ``hada_t2``). Subclasses
    implement ``_get_delta_activations`` to apply the delta weight to an input.
    """

    # All names of layers that may contain adapter weights
    adapter_layer_names = ("hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2")
    # other_param_names is defined on parent class

    def __init__(self, base_layer: nn.Module):
        super().__init__()
        LycorisLayer.__init__(self, base_layer)

        # LoHa info: each ParameterDict maps adapter name -> parameter
        self.hada_w1_a = nn.ParameterDict({})
        self.hada_w1_b = nn.ParameterDict({})
        self.hada_w2_a = nn.ParameterDict({})
        self.hada_w2_b = nn.ParameterDict({})
        self.hada_t1 = nn.ParameterDict({})
        self.hada_t2 = nn.ParameterDict({})

    @property
    def _available_adapters(self) -> Set[str]:
        """Names of adapters that own at least one LoHa parameter on this layer (a new set on every access)."""
        return {*self.hada_w1_a, *self.hada_w1_b, *self.hada_w2_a, *self.hada_w2_b, *self.hada_t1, *self.hada_t2}

    def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...]):
        """Allocate uninitialized LoHa parameters for ``adapter_name``.

        Args:
            adapter_name: Key under which the new parameters are stored.
            r: Rank of the decomposition.
            shape: ``(out_dim, in_dim)`` for the plain decomposition, or ``(out_dim, in_dim, k_h, k_w)`` for the
                effective Conv2d decomposition, which also creates the core tensors ``hada_t1`` / ``hada_t2``.
        """
        # https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L130C9-L143C75
        if len(shape) == 4:
            self.hada_t1[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3]))
            self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0]))  # out_dim, 1-mode
            self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1]))  # in_dim , 2-mode

            self.hada_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3]))
            self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0]))  # out_dim, 1-mode
            self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1]))  # in_dim , 2-mode
        else:
            self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0], r))
            self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1]))

            self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0], r))
            self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1]))

    def reset_adapter_parameters(self, adapter_name: str):
        """Initialize ``adapter_name`` so that its delta weight starts at exactly zero.

        All factors get He (Kaiming-uniform) init except ``hada_w2_b``, which is zeroed. Since the delta is an
        element-wise product with the second (zeroed) branch, the adapter is a no-op right after initialization.
        """
        # Original implementation performs initialization with normal distribution
        # https://github.com/KohakuBlueleaf/LyCORIS/blob/3549fdef8f564761d68b695a08ef88b1122fdedc/lycoris/modules/loha.py#L158

        # FedPara paper proposes to perform He initialization, let's stick with it
        # It is enough to initialize only single matrix with zeros to make adapter do nothing after initialization
        if adapter_name in self.hada_w1_a.keys():
            nn.init.kaiming_uniform_(self.hada_w1_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.hada_w1_b[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.hada_w2_a[adapter_name], a=math.sqrt(5))
            nn.init.zeros_(self.hada_w2_b[adapter_name])
        if adapter_name in self.hada_t1.keys():
            nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5))

    def reset_adapter_parameters_random(self, adapter_name: str):
        """Initialize every LoHa parameter of ``adapter_name`` with He (Kaiming-uniform) init.

        Unlike ``reset_adapter_parameters`` no factor is zeroed, so the delta weight is generally non-zero right
        after initialization.
        """
        # Original implementation performs initialization with normal distribution
        # https://github.com/KohakuBlueleaf/LyCORIS/blob/3549fdef8f564761d68b695a08ef88b1122fdedc/lycoris/modules/loha.py#L158

        # FedPara paper proposes to perform He initialization, let's stick with it.
        # Deliberately no zero-initialized factor here: this variant does NOT start as a no-op.
        if adapter_name in self.hada_w1_a.keys():
            nn.init.kaiming_uniform_(self.hada_w1_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.hada_w1_b[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.hada_w2_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.hada_w2_b[adapter_name], a=math.sqrt(5))
        if adapter_name in self.hada_t1.keys():
            nn.init.kaiming_uniform_(self.hada_t1[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.hada_t2[adapter_name], a=math.sqrt(5))

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        alpha: float,
        rank_dropout: float,
        module_dropout: float,
        init_weights: bool,
        use_effective_conv2d: bool = False,
        **kwargs,
    ) -> None:
        """Internal function to create loha adapter

        Args:
            adapter_name (`str`): Name for the adapter to add.
            r (`int`): Rank for the added adapter. Must be positive.
            alpha (`float`): Alpha for the added adapter; the delta weight is scaled by `alpha / r`.
            rank_dropout (`float`): Probability of dropping rows of the delta weight during training.
            module_dropout (`float`): The dropout probability for disabling adapter during training.
            init_weights (`bool`): If `True`, initialize so the adapter starts as a no-op (`hada_w2_b` zeroed);
                if `False`, initialize every factor randomly.
            use_effective_conv2d (`bool`, *optional*, defaults to `False`):
                Use parameter effective decomposition for Conv2d with ksize > 1.

        Raises:
            ValueError: If `r` is not positive.
            TypeError: If the base layer is neither `nn.Linear` nor `nn.Conv2d`.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.r[adapter_name] = r
        self.alpha[adapter_name] = alpha
        self.scaling[adapter_name] = alpha / r
        self.rank_dropout[adapter_name] = rank_dropout
        self.module_dropout[adapter_name] = module_dropout

        # Determine shape of LoHa weights
        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            shape = tuple(base_layer.weight.shape)
        elif isinstance(base_layer, nn.Conv2d):
            use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
            # A grouped conv's weight is (out_channels, in_channels // groups, k_h, k_w); sizing the factors with
            # the full in_channels would make get_delta_weight's reshape fail for groups > 1.
            in_dim = base_layer.in_channels // base_layer.groups
            if use_effective_conv2d:
                shape = (base_layer.out_channels, in_dim, *base_layer.kernel_size)
            else:
                shape = (
                    base_layer.out_channels,
                    in_dim * base_layer.kernel_size[0] * base_layer.kernel_size[1],
                )
        else:
            raise TypeError(f"LoHa is not implemented for base layers of type {type(base_layer).__name__}")

        # Create weights with provided shape
        self.create_adapter_parameters(adapter_name, r, shape)

        # Initialize weights
        if init_weights:
            self.reset_adapter_parameters(adapter_name)
        else:
            self.reset_adapter_parameters_random(adapter_name)

        # Move new weights to device
        weight = getattr(self.get_base_layer(), "weight", None)
        if weight is not None:
            # the layer is already completely initialized, this is an update
            if weight.dtype.is_floating_point or weight.dtype.is_complex:
                self.to(weight.device, dtype=weight.dtype)
            else:
                self.to(weight.device)
        self.set_adapter(self.active_adapters)

    def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
        """Return the scaled delta weight of ``adapter_name``, reshaped to the base layer's weight shape.

        While training with a non-zero ``rank_dropout``, rows (dim 0) of the delta are randomly zeroed and the
        surviving rows are divided by the kept fraction.
        """
        # https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L178
        scale = torch.tensor(self.scaling[adapter_name])
        if adapter_name in self.hada_t1.keys():
            weight = make_weight_cp(
                self.hada_t1[adapter_name],
                self.hada_w1_a[adapter_name],
                self.hada_w1_b[adapter_name],
                self.hada_t2[adapter_name],
                self.hada_w2_a[adapter_name],
                self.hada_w2_b[adapter_name],
                scale=scale,
            )
        else:
            weight = make_weight(
                self.hada_w1_a[adapter_name],
                self.hada_w1_b[adapter_name],
                self.hada_w2_a[adapter_name],
                self.hada_w2_b[adapter_name],
                scale=scale,
            )

        base_layer = self.get_base_layer()
        weight = weight.reshape(base_layer.weight.shape)

        # Perform rank dropout during training - drop rows of addition weights
        rank_dropout = self.rank_dropout[adapter_name]
        if self.training and rank_dropout:
            drop = (torch.rand(weight.size(0)) > rank_dropout).to(weight.dtype)
            drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
            # TODO: Investigate if there should be a scaler like in normal dropout during training
            # Original implementation doesn't have it
            # https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L193
            # If every row was dropped the mean is 0 and plain division would produce 0 / 0 = NaN across the whole
            # delta; divide by 1 instead in that case so the all-zero mask simply zeroes the delta.
            drop_mean = drop.mean()
            drop = drop / torch.where(drop_mean > 0, drop_mean, torch.ones_like(drop_mean))
            # Out-of-place so the autograd Function's output is not modified in place
            weight = weight * drop

        return weight

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the base layer and add the contribution of every active adapter present on this layer.

        If adapters are disabled, ``self.unmerge()`` is called first (when merged) and only the base layer runs.
        If merged, only the base layer runs. Otherwise, each active adapter adds ``_get_delta_activations``;
        during training an adapter is skipped with probability ``module_dropout[adapter]``. The result is cast
        back to the dtype of ``x``.
        """
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)

            # Execute all the adapters; the property builds a fresh set on every access, so compute it once
            available_adapters = self._available_adapters
            for active_adapter in self.active_adapters:
                if active_adapter not in available_adapters:
                    continue

                module_dropout = self.module_dropout[active_adapter]

                # Modify current execution weights (module dropout only applies while training)
                if not self.training or torch.rand(1) > module_dropout:
                    result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs)

        result = result.to(previous_dtype)
        return result
class Linear(LoHaLayer):
    """LoHa adapter wrapped around an ``nn.Linear`` base layer."""

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str = "default",
        r: int = 0,
        alpha: float = 0.0,
        rank_dropout: float = 0.0,
        module_dropout: float = 0.0,
        init_weights: bool = True,
        **kwargs,
    ):
        super().__init__(base_layer)

        # Register the initial adapter and mark it as the active one
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            alpha,
            rank_dropout,
            module_dropout,
            init_weights,
            **kwargs,
        )

    def _get_delta_activations(
        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        """Project ``input`` through this adapter's delta weight, without a bias term."""
        # No bias: the base layer's output, to which this is added, already contains it
        return F.linear(input, self.get_delta_weight(adapter_name))

    def __repr__(self) -> str:
        return "loha." + super().__repr__()
class Conv2d(LoHaLayer):
    """LoHa implemented in Conv2d layer"""

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str = "default",
        r: int = 0,
        alpha: float = 0.0,
        rank_dropout: float = 0.0,
        module_dropout: float = 0.0,
        use_effective_conv2d: bool = False,
        init_weights: bool = True,
        **kwargs,
    ):
        super().__init__(base_layer)

        # Create adapter and set it active
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
        )

    def _get_delta_activations(
        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        """Convolve ``input`` with this adapter's delta weight using the base layer's conv hyper-parameters.

        The bias is not added because it is already included in the base layer's output. A non-``"zeros"``
        ``padding_mode`` on the base layer is honored the same way ``nn.Conv2d`` does it: explicit ``F.pad``
        followed by an unpadded convolution.
        """
        delta_weight = self.get_delta_weight(adapter_name)
        base_layer = self.get_base_layer()

        padding_mode = getattr(base_layer, "padding_mode", "zeros")
        if padding_mode != "zeros":
            # F.conv2d's ``padding`` argument always pads with zeros. Pad exactly like nn.Conv2d does, so the delta
            # sees the same (e.g. reflect-padded) input as the base weight; otherwise merged and unmerged outputs
            # would disagree along the borders.
            input = F.pad(input, base_layer._reversed_padding_repeated_twice, mode=padding_mode)
            padding = 0
        else:
            padding = base_layer.padding

        return F.conv2d(
            input,
            delta_weight,
            stride=base_layer.stride,
            padding=padding,
            dilation=base_layer.dilation,
            groups=base_layer.groups,
        )

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "loha." + rep
# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L9
class HadaWeight(torch.autograd.Function):
    """Autograd op computing ``((w1a @ w1b) * (w2a @ w2b)) * scale`` with a hand-written backward.

    Gradients are returned for the four factors; ``scale`` gets ``None``.
    """

    @staticmethod
    def forward(ctx, w1a, w1b, w2a, w2b, scale=torch.tensor(1)):
        ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
        first = w1a @ w1b
        second = w2a @ w2b
        return (first * second) * scale

    @staticmethod
    def backward(ctx, grad_out):
        w1a, w1b, w2a, w2b, scale = ctx.saved_tensors
        scaled = grad_out * scale

        # Gradient w.r.t. the first product is scaled * (w2a @ w2b); chain it through the matmul
        upstream = scaled * (w2a @ w2b)
        grad_w1a = upstream @ w1b.T
        grad_w1b = w1a.T @ upstream

        # Symmetric for the second product
        upstream = scaled * (w1a @ w1b)
        grad_w2a = upstream @ w2b.T
        grad_w2b = w2a.T @ upstream
        del upstream

        return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
class HadaWeightCP(torch.autograd.Function):
    """LoHa delta weight with CP/Tucker core tensors (effective Conv2d decomposition).

    With ``t*`` of shape ``(r, r, k_h, k_w)``, ``w*a`` of shape ``(r, out_dim)`` and ``w*b`` of shape
    ``(r, in_dim)`` (as allocated by ``LoHaLayer.create_adapter_parameters``), each branch rebuilds
    ``einsum("i j k l, j r, i p -> p r k l", t, wb, wa)`` of shape ``(out_dim, in_dim, k_h, k_w)``; the output is
    ``rebuild1 * rebuild2 * scale``. ``scale`` receives no gradient.

    NOTE: ``backward`` intentionally deviates from the LyCORIS code this was copied from. Upstream contracts the
    ``w1a`` gradient with the branch-2 partial product ``t2 x w2b`` (and the ``w2a`` gradient with ``t1 x w1b``),
    which is wrong whenever the two branches differ. Here each ``w*a`` gradient uses its own branch.
    """

    @staticmethod
    def forward(ctx, t1, w1a, w1b, t2, w2a, w2b, scale=torch.tensor(1)):
        ctx.save_for_backward(t1, w1a, w1b, t2, w2a, w2b, scale)

        rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a)
        rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a)

        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        (t1, w1a, w1b, t2, w2a, w2b, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        # Per-branch partial products t x wb, each of shape (r, in_dim, k_h, k_w)
        temp1 = torch.einsum("i j k l, j r -> i r k l", t1, w1b)
        temp2 = torch.einsum("i j k l, j r -> i r k l", t2, w2b)

        # ---- branch 1: dL/d rebuild1 = grad_out * rebuild2 ----
        rebuild = torch.einsum("i j k l, i r -> r j k l", temp2, w2a)
        grad_w = rebuild * grad_out
        del rebuild

        # rebuild1 = temp1 contracted with w1a over the rank axis, so dL/dw1a contracts grad_w with temp1
        grad_w1a = torch.einsum("r j k l, i j k l -> r i", temp1, grad_w)
        grad_temp = torch.einsum("i j k l, i r -> r j k l", grad_w, w1a.T)
        del grad_w

        grad_w1b = torch.einsum("i r k l, i j k l -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j k l, j r -> i r k l", grad_temp, w1b.T)
        del grad_temp

        # ---- branch 2: dL/d rebuild2 = grad_out * rebuild1 ----
        rebuild = torch.einsum("i j k l, i r -> r j k l", temp1, w1a)
        del temp1
        grad_w = rebuild * grad_out
        del rebuild

        grad_w2a = torch.einsum("r j k l, i j k l -> r i", temp2, grad_w)
        del temp2
        grad_temp = torch.einsum("i j k l, i r -> r j k l", grad_w, w2a.T)
        del grad_w

        grad_w2b = torch.einsum("i r k l, i j k l -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j k l, j r -> i r k l", grad_temp, w2b.T)
        del grad_temp
        return grad_t1, grad_w1a, grad_w1b, grad_t2, grad_w2a, grad_w2b, None
def make_weight(w1a, w1b, w2a, w2b, scale):
    """Return ``((w1a @ w1b) * (w2a @ w2b)) * scale`` computed through the ``HadaWeight`` autograd op."""
    factors = (w1a, w1b, w2a, w2b)
    return HadaWeight.apply(*factors, scale)
def make_weight_cp(t1, w1a, w1b, t2, w2a, w2b, scale):
    """Return the CP-decomposed LoHa delta weight computed through the ``HadaWeightCP`` autograd op."""
    first_branch = (t1, w1a, w1b)
    second_branch = (t2, w2a, w2b)
    return HadaWeightCP.apply(*first_branch, *second_branch, scale)