khoicrtp's picture
init
12001a9
# Derived from https://github.com/microsoft/LoRA
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import lit_llama.model as llama
class LoRALayer:
    """Mixin that stores the hyper-parameters shared by LoRA-augmented layers.

    Args:
        r: Rank of the low-rank update. Concrete layers such as
            ``MergedLinear`` create no LoRA parameters when this is 0.
        lora_alpha: Scaling numerator; ``MergedLinear`` uses
            ``lora_alpha / r`` as its scaling factor.
        lora_dropout: Dropout probability for the LoRA branch input. A value
            of 0 (or less) means no dropout is applied.
        merge_weights: Whether concrete layers fold the LoRA delta into the
            frozen weight when switching to eval mode.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Dropout only touches the LoRA branch; an identity function otherwise.
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)
        )
        # Start unmerged: the LoRA delta has not been folded into the weight.
        self.merged = False
        self.merge_weights = merge_weights
class MergedLinear(nn.Linear, LoRALayer):
    """``nn.Linear`` whose output is split into equal slices, some with a LoRA update.

    The ``out_features`` outputs are divided into ``len(enable_lora)`` equal
    slices. A rank-``r`` update, scaled by ``lora_alpha / r``, is learned only
    for the slices whose flag in ``enable_lora`` is True. When LoRA is active
    the pre-trained weight is frozen. ``lora_B`` is applied as a grouped
    1x1 ``conv1d`` with ``groups=sum(enable_lora)``, so each enabled slice has
    its own independent B matrix.

    Args:
        in_features: Input width, as for ``nn.Linear``.
        out_features: Output width, as for ``nn.Linear``.
        r: LoRA rank. 0 disables LoRA.
        lora_alpha: Scaling numerator.
        lora_dropout: Dropout probability on the LoRA branch input.
        enable_lora: One flag per output slice. Its length must divide
            ``out_features``. ``None`` means ``[False]`` (LoRA disabled).
        fan_in_fan_out: Store the weight transposed, as ``(in, out)``.
        merge_weights: Fold the LoRA delta into the weight in eval mode.
        **kwargs: Forwarded to ``nn.Linear`` (e.g. ``bias``).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: Optional[List[bool]] = None,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        # Build a fresh list per instance rather than sharing a mutable default.
        # list() also lets a tuple be used as the boolean row mask below.
        enable_lora = [False] if enable_lora is None else list(enable_lora)
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        # The length check short-circuits, so an empty list fails this assert
        # instead of raising ZeroDivisionError.
        assert len(enable_lora) > 0 and out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters (only when some slice is enabled and r > 0).
        if self._lora_active():
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            # Weights for a conv1d with groups=sum(enable_lora).
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Boolean mask over output features: True for rows in enabled slices.
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def _lora_active(self) -> bool:
        """Return True when the LoRA parameters exist (r > 0 and a slice is enabled)."""
        return self.r > 0 and any(self.enable_lora)

    def reset_parameters(self):
        """Reset the base ``nn.Linear`` parameters and, if present, the LoRA ones."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A exists, hence the guard.
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        """Scatter enabled-slice features into a zero tensor of width ``out_features``.

        ``x`` holds ``out_features // len(enable_lora) * sum(enable_lora)``
        features, covering only the enabled slices. The result places them at
        the positions selected by ``self.lora_ind`` and leaves disabled slices
        at zero. Because of the surrounding ``transpose(0, 1)``, the padded axis
        is the last one for inputs with 3 or more dims, but axis 0 for 2-D
        inputs.
        """
        x = x.transpose(0, 1)
        result = x.new_zeros((*x.shape[:-1], self.out_features))
        result = result.view(-1, self.out_features)
        result[:, self.lora_ind] = x.reshape(
            -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
        )
        return result.view((*x.shape[:-1], self.out_features)).transpose(0, 1)

    def train(self, mode: bool = True):
        """Set train/eval mode and unmerge/merge the LoRA delta accordingly.

        ``train(True)`` subtracts a previously merged delta from the weight.
        ``train(False)`` adds the delta, which makes eval-mode forward a single
        matmul. Merging only happens when ``merge_weights`` is True. Returns
        ``self`` to match the ``nn.Module.train`` contract, so ``eval()``
        returns the module too.
        """
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        # if train(True) -> unmerge unless we already have them unmerged
        # if train(False) -> merge unless we already have them merged
        should = self.merged if mode else not self.merged
        if self.merge_weights and should:
            if self._lora_active():
                # Grouped conv of A by B gives the (enabled_out, in) delta.
                delta_w = F.conv1d(
                    self.lora_A.data.unsqueeze(0),
                    self.lora_B.data.unsqueeze(-1),
                    groups=sum(self.enable_lora)
                ).squeeze(0)
                # -1: W = W - delta_W (unmerge), +1: W = W + delta_W (merge)
                sign = -1 if mode else 1
                # Pad in (out, in) layout first, then transpose to the weight's
                # storage layout. Padding after T broke fan_in_fan_out=True.
                self.weight.data += sign * T(self.zero_pad(delta_w * self.scaling))
            self.merged = not mode
        return self

    def forward(self, x: torch.Tensor):
        """Apply the frozen linear map plus, when unmerged, the scaled LoRA update.

        NOTE(review): the LoRA branch relies on ``zero_pad`` padding the last
        axis, which only holds for inputs with 3 or more dims, e.g.
        ``(batch, seq, in_features)``.
        """
        def T(w):
            return w.T if self.fan_in_fan_out else w
        result = F.linear(x, T(self.weight), bias=self.bias)
        # Merged weights already contain the delta. Without LoRA parameters
        # there is nothing to add; the old `r > 0` check crashed here when
        # every slice was disabled.
        if self.merged or not self._lora_active():
            return result
        after_A = F.linear(self.lora_dropout(x), self.lora_A)
        after_B = F.conv1d(
            after_A.transpose(-2, -1),
            self.lora_B.unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).transpose(-2, -1)
        return result + self.zero_pad(after_B) * self.scaling
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    """Freeze every parameter except LoRA ones, optionally leaving biases trainable.

    Args:
        model: Module whose parameters' ``requires_grad`` flags are updated in place.
        bias: ``'none'`` freezes all biases. ``'all'`` re-enables every
            parameter whose name contains ``'bias'``. ``'lora_only'``
            re-enables the ``bias`` of ``LoRALayer`` modules only.

    Raises:
        NotImplementedError: If ``bias`` is not one of the modes above. This
            is checked before any flag is changed, so the model is untouched.
    """
    if bias not in ('none', 'all', 'lora_only'):
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    # Parameters are recognized as LoRA ones by the 'lora_' name prefix.
    for name, param in model.named_parameters():
        if 'lora_' not in name:
            param.requires_grad = False
    if bias == 'all':
        for name, param in model.named_parameters():
            if 'bias' in name:
                param.requires_grad = True
    elif bias == 'lora_only':
        for module in model.modules():
            if isinstance(module, LoRALayer) and getattr(module, 'bias', None) is not None:
                module.bias.requires_grad = True
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Return the subset of ``model.state_dict()`` needed to restore the LoRA weights.

    Args:
        model: Module to export from.
        bias: ``'none'`` keeps only keys containing ``'lora_'``. ``'all'``
            also keeps every key containing ``'bias'``. ``'lora_only'`` keeps
            LoRA keys plus the ``bias`` entry sharing their module prefix.

    Raises:
        NotImplementedError: For any other ``bias`` value.
    """
    full_state = model.state_dict()
    if bias in ('none', 'all'):
        markers = ('lora_',) if bias == 'none' else ('lora_', 'bias')
        return {
            key: value
            for key, value in full_state.items()
            if any(marker in key for marker in markers)
        }
    if bias == 'lora_only':
        selected = {}
        for key, value in full_state.items():
            if 'lora_' not in key:
                continue
            selected[key] = value
            # "<prefix>lora_X" -> "<prefix>bias": the owning module's bias.
            sibling_bias = key.partition('lora_')[0] + 'bias'
            if sibling_bias in full_state:
                selected[sibling_bias] = full_state[sibling_bias]
        return selected
    raise NotImplementedError
@dataclass
class LoRAConfig:
    """LoRA hyper-parameters read by the LoRA ``CausalSelfAttention``.

    Attributes:
        r: LoRA rank, an integer like ``MergedLinear``'s ``r``. 0 disables LoRA.
        alpha: Scaling numerator. ``MergedLinear`` scales the update by ``alpha / r``.
        dropout: Dropout probability on the LoRA branch input.
    """
    # Rank is an integer count; a float here would break tensor-size arithmetic.
    r: int = 0
    alpha: float = 1.0
    dropout: float = 0.0
class CausalSelfAttention(llama.CausalSelfAttention):
    """``lit_llama.model.CausalSelfAttention`` with a LoRA-enabled fused projection.

    The LoRA hyper-parameters come from the class attribute ``lora_config``,
    which the ``lora`` context manager sets while it is active.
    """

    # LoRAConfig set by the `lora` context manager; None outside of it.
    lora_config = None

    def __init__(self, config: llama.LLaMAConfig) -> None:
        if self.lora_config is None:
            raise RuntimeError(
                "CausalSelfAttention.lora_config is not set; instantiate the "
                "model inside the `lora(...)` context manager."
            )
        # Skip the parent class __init__ altogether and replace it to avoid
        # useless allocations
        nn.Module.__init__(self)
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch.
        # enable_lora=[True, False, True] adds LoRA to the first and last
        # n_embd-wide thirds of the output only (presumably query and value,
        # if the parent splits the output as q, k, v - confirm in lit_llama.model).
        self.c_attn = MergedLinear(
            in_features=config.n_embd,
            out_features=3 * config.n_embd,
            r=self.lora_config.r,
            lora_alpha=self.lora_config.alpha,
            lora_dropout=self.lora_config.dropout,
            enable_lora=[True, False, True],
            fan_in_fan_out=False,
            merge_weights=True,
            bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        self.rope_cache = None
@contextmanager
def lora(r, alpha, dropout, enabled: bool = True):
    """A context manager under which you can instantiate the model with LoRA.

    While active, ``lit_llama.model.CausalSelfAttention`` is replaced by this
    module's LoRA-enabled ``CausalSelfAttention``, configured with ``r``,
    ``alpha`` and ``dropout``. The original class is restored and the config
    cleared on exit, including when the body raises.

    Args:
        r: LoRA rank.
        alpha: LoRA scaling numerator.
        dropout: Dropout probability on the LoRA branch.
        enabled: When False, the context manager does nothing.
    """
    if not enabled:
        yield
        return
    CausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
    causal_self_attention = llama.CausalSelfAttention
    llama.CausalSelfAttention = CausalSelfAttention
    try:
        yield
    finally:
        # Always undo the global monkey-patch, even if model construction failed.
        llama.CausalSelfAttention = causal_self_attention
        CausalSelfAttention.lora_config = None