import math
import random
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
class DyLoraModule(nn.Module):
    """
    Implementation of Dynamic Low-Rank Adaptation (DyLoRA)
    """
    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.,
        use_cp=False,
        block_size=1,
        **kwargs,
    ):
""" if alpha == 0 or None, alpha is rank (no scaling). """
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
assert lora_dim % block_size == 0, 'lora_dim must be a multiple of block_size'
self.block_count = lora_dim//block_size
self.block_size = block_size
        self.shape = org_module.weight.shape
        if org_module.__class__.__name__ == 'Conv2d':
            # For Conv2d, build the delta over the flattened kernel,
            # (out, in*kh*kw); forward() reshapes it back to the weight shape.
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels
            shape = (out_dim, in_dim * k_size[0] * k_size[1])
            self.op = F.conv2d
            self.extra_args = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups,
            }
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            shape = (out_dim, in_dim)
            self.op = F.linear
            self.extra_args = {}
        # One rank-1 column/row per rank. `up_list`/`down_list` hold the
        # registered parameters; `up_update`/`down_update` are plain-tensor
        # shadow copies of the frozen blocks (deliberately unregistered).
        self.up_list = nn.ParameterList([
            nn.Parameter(torch.empty(shape[0], 1))
            for _ in range(lora_dim)
        ])
        self.up_list.requires_grad_(False)
        self.up_update = [
            torch.zeros_like(self.up_list[i])
            for i in range(lora_dim)
        ]
        self.down_list = nn.ParameterList([
            nn.Parameter(torch.empty(1, shape[1]))
            for _ in range(lora_dim)
        ])
        self.down_list.requires_grad_(False)
        self.down_update = [
            torch.zeros_like(self.down_list[i])
            for i in range(lora_dim)
        ]
        self.index = 0
        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))  # can be treated as a constant
        # Need more experiments on the init method.
        for v in self.down_list:
            torch.nn.init.kaiming_uniform_(v, a=math.sqrt(5))
        for v in self.up_list:
            torch.nn.init.zeros_(v)
        # Start the shadow copies in sync with the freshly initialized parameters.
        for i, v in enumerate(self.up_update):
            v.copy_(self.up_list[i])
        for i, v in enumerate(self.down_update):
            v.copy_(self.down_list[i])
        self.multiplier = multiplier
        self.org_module = [org_module]  # kept in a list so it is not registered as a submodule
        self.grad_ckpt = False
        self.apply_train(0)  # start with block 0 trainable
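    # NOTE: `multiplier` is stored above but never applied in forward() below;
    # a caller expecting LoRA-style strength scaling would have to fold it
    # into the weight delta itself.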
    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        # TODO: Remove `args` and the parsing logic when BC allows.
        if len(args) > 0:
            if destination is None:
                destination = args[0]
            if len(args) > 1 and prefix == '':
                prefix = args[1]
            if len(args) > 2 and keep_vars is False:
                keep_vars = args[2]
            # DeprecationWarning is ignored by default
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        local_metadata = dict(version=self._version)
        if hasattr(destination, "_metadata"):
            destination._metadata[prefix[:-1]] = local_metadata
        # Export the shadow copies in the standard LoRA layout:
        # lora_up.weight is (out_dim, lora_dim), lora_down.weight is (lora_dim, in_dim).
        # Note: the block currently being trained lives in up_list/down_list and
        # is only synced back into the shadow copies on the next apply_train() call.
        destination[f'{prefix}alpha'] = self.alpha
        destination[f'{prefix}lora_up.weight'] = nn.Parameter(
            torch.concat(self.up_update, dim=1)
        )
        destination[f'{prefix}lora_down.weight'] = nn.Parameter(
            torch.concat(self.down_update)
        )
        return destination
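    # Because the export above is ordered by rank, a saved DyLoRA checkpoint
    # can in principle be truncated to any rank r that is a multiple of
    # block_size by slicing lora_up.weight[:, :r] and lora_down.weight[:r, :]
    # and adjusting alpha accordingly.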
    def apply_to(self):
        # Reroute the wrapped module's forward through this adapter.
        self.org_module[0].forward = self.forward
    def apply_train(self, b: int):
        # Switch the trainable block from self.index to b:
        self.up_list.requires_grad_(False)
        self.down_list.requires_grad_(False)
        # 1) save the block trained so far into the shadow copies,
        for i in range(self.index * self.block_size, (self.index + 1) * self.block_size):
            self.up_update[i].copy_(self.up_list[i])
            self.down_update[i].copy_(self.down_list[i])
        # 2) load block b from the shadow copies into the parameters,
        for i in range(b * self.block_size, (b + 1) * self.block_size):
            self.up_list[i].copy_(self.up_update[i])
            self.down_list[i].copy_(self.down_update[i])
        # 3) re-enable grad; only block b enters the graph in forward(),
        #    so only it actually receives gradients.
        self.up_list.requires_grad_(True)
        self.down_list.requires_grad_(True)
        self.index = b
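    # Example: with lora_dim=8 and block_size=2, apply_train(3) saves the
    # previously active pair of columns/rows and makes indices 6 and 7 the
    # trainable block.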
    @torch.enable_grad()
    def forward(self, x):
        # Sample the rank block to use (and, when training, to update) this step.
        b = random.randint(0, self.block_count - 1)
        # Lazily move the shadow copies to the parameters' device.
        if self.up_update[b].device != self.up_list[b].device:
            device = self.up_list[b].device
            for i in range(self.lora_dim):
                self.up_update[i] = self.up_update[i].to(device)
                self.down_update[i] = self.down_update[i].to(device)
        if self.training:
            self.apply_train(b)
        # Build the truncated factors: frozen shadow copies below the active
        # block, live parameters inside it. down is ((b+1)*block_size, in_dim),
        # up is (out_dim, (b+1)*block_size).
        down = torch.concat(
            list(self.down_update[:b * self.block_size])
            + list(self.down_list[b * self.block_size:(b + 1) * self.block_size])
        )
        up = torch.concat(
            list(self.up_update[:b * self.block_size])
            + list(self.up_list[b * self.block_size:(b + 1) * self.block_size]),
            dim=1
        )
        bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
        # Apply the original op with the low-rank delta added to the weight,
        # scaled by alpha over the active block count.
        return self.op(
            x,
            self.org_module[0].weight + (up @ down).view(self.shape) * self.alpha / (b + 1),
            bias,
            **self.extra_args
        )
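

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module;
# the layer sizes and the "example_lora" name are arbitrary assumptions).
# It wraps a plain nn.Linear, runs one training-mode forward pass, and
# exports the adapter in the standard lora_up/lora_down layout.
if __name__ == "__main__":
    base = nn.Linear(16, 32)
    module = DyLoraModule("example_lora", base, lora_dim=4, alpha=4, block_size=1)
    module.apply_to()  # base.forward now goes through the adapter

    module.train()
    y = base(torch.randn(2, 16))  # a random rank block is selected and trained
    print(y.shape)  # torch.Size([2, 32])

    sd = module.state_dict(prefix="example_lora.")
    print(sd["example_lora.lora_up.weight"].shape)    # torch.Size([32, 4])
    print(sd["example_lora.lora_down.weight"].shape)  # torch.Size([4, 16])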