import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    Return a tuple of two values whose product is the input dimension,
    decomposed using the divisor closest to `factor`.
    The second value is greater than or equal to the first value.

    In LoRA with Kronecker product, the first value is used for the weight scale,
    the second value for the weight itself.

    Because the Kronecker product is non-commutative, A⊗B ≠ B⊗A, so the meaning
    of the two matrices is slightly different.

    examples)
    factor
            -1               2                4                8                16   ...
    127  -> 127, 1      127 -> 127, 1    127 -> 127, 1    127 -> 127, 1    127 -> 127, 1
    128  -> 16, 8       128 -> 64, 2     128 -> 32, 4     128 -> 16, 8     128 -> 16, 8
    250  -> 125, 2      250 -> 125, 2    250 -> 125, 2    250 -> 125, 2    250 -> 125, 2
    360  -> 45, 8       360 -> 180, 2    360 -> 90, 4     360 -> 45, 8     360 -> 45, 8
    512  -> 32, 16      512 -> 256, 2    512 -> 128, 4    512 -> 64, 8     512 -> 32, 16
    1024 -> 32, 32      1024 -> 512, 2   1024 -> 256, 4   1024 -> 128, 8   1024 -> 64, 16
    '''

    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        return m, n
    if factor == -1:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n
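
# A quick sanity check of factorization (a sketch; the outputs below follow from
# tracing the function above):
#
#   factorization(1024)    # (32, 32) -- closest-to-square split for factor=-1
#   factorization(128, 4)  # (4, 32)  -- 4 divides 128, so the shortcut branch fires
#   factorization(127)     # (1, 127) -- a prime dimension cannot be split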


def make_weight_cp(t, wa, wb):
    rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb)
    return rebuild2
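
# Shape sketch for make_weight_cp (illustrative sizes, not from the original file):
# with t of shape (rank, rank, kh, kw), wa of shape (rank, out) and wb of shape
# (rank, in), the einsum contracts both rank axes and yields a full convolution
# kernel of shape (out, in, kh, kw), e.g.
#
#   t = torch.randn(4, 4, 3, 3)
#   wa = torch.randn(4, 8)
#   wb = torch.randn(4, 16)
#   make_weight_cp(t, wa, wb).shape  # torch.Size([8, 16, 3, 3])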


def make_kron(orig_weight, w1, w2, scale):
    if len(w2.shape) == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    return orig_weight + torch.kron(w1, w2).reshape(orig_weight.shape) * scale
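
# Worked shape example for make_kron (a sketch with made-up sizes): for a Linear
# weight of shape (6, 8), a (2, 2) w1 and a (3, 4) w2 give torch.kron(w1, w2) of
# shape (6, 8), which is scaled and added onto the original weight:
#
#   orig = torch.zeros(6, 8)
#   make_kron(orig, torch.randn(2, 2), torch.randn(3, 4), 1.0).shape  # torch.Size([6, 8])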


class LokrModule(nn.Module):
    """
    Modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
    and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
    and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
    """
    def __init__(
        self,
        lora_name, org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4, alpha=1,
        dropout=0.,
        use_cp=False,
        decompose_both=False,
        factor: int = -1,
        **kwargs,
    ):
        """ If alpha == 0 or None, alpha is rank (no scaling). """
        super().__init__()
        factor = int(factor)
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.cp = False
        self.use_w1 = False
        self.use_w2 = False

        self.shape = org_module.weight.shape
        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels

            in_m, in_n = factorization(in_dim, factor)
            out_l, out_k = factorization(out_dim, factor)
            shape = ((out_l, out_k), (in_m, in_n), *k_size)

            self.cp = use_cp and k_size != (1, 1)
            if decompose_both and lora_dim < max(shape[0][0], shape[1][0]) / 2:
                self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
            else:
                self.use_w1 = True
                self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))

            if lora_dim >= max(shape[0][1], shape[1][1]) / 2:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
            elif self.cp:
                self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
                self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1]))
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
            else:
                self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1] * shape[2] * shape[3]))

            self.op = F.conv2d
            self.extra_args = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups
            }
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

            in_m, in_n = factorization(in_dim, factor)
            out_l, out_k = factorization(out_dim, factor)
            shape = ((out_l, out_k), (in_m, in_n))

            if decompose_both and lora_dim < max(shape[0][0], shape[1][0]) / 2:
                self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
                self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
            else:
                self.use_w1 = True
                self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))

            if lora_dim < max(shape[0][1], shape[1][1]) / 2:
                self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
            else:
                self.use_w2 = True
                self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))

            self.op = F.linear
            self.extra_args = {}

        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        if self.use_w2 and self.use_w1:
            alpha = lora_dim
        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))
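        # Scaling follows the usual LoRA convention of alpha / rank: with the defaults
        # alpha=1 and lora_dim=4 this gives scale = 0.25. When both w1 and w2 are kept
        # as full matrices (use_w1 and use_w2), the low-rank dimension is unused, so
        # alpha is forced to lora_dim above and the effective scale is 1.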

        if self.use_w2:
            torch.nn.init.constant_(self.lokr_w2, 0)
        else:
            if self.cp:
                torch.nn.init.normal_(self.lokr_t2, std=0.1)
            torch.nn.init.normal_(self.lokr_w2_a, std=1)
            torch.nn.init.constant_(self.lokr_w2_b, 0)

        if self.use_w1:
            torch.nn.init.normal_(self.lokr_w1, std=1)
        else:
            torch.nn.init.normal_(self.lokr_w1_a, std=1)
            torch.nn.init.normal_(self.lokr_w1_b, std=0.1)

        self.multiplier = multiplier
        self.org_module = [org_module]
        weight = make_kron(
            self.org_module[0].weight.data,
            self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b,
            (self.lokr_w2 if self.use_w2
             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
             else self.lokr_w2_a @ self.lokr_w2_b),
            torch.tensor(self.multiplier * self.scale)
        )
        assert torch.sum(torch.isnan(weight)) == 0, "weight is nan"

    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def forward(self, x):
        weight = make_kron(
            self.org_module[0].weight.data,
            self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b,
            (self.lokr_w2 if self.use_w2
             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
             else self.lokr_w2_a @ self.lokr_w2_b),
            torch.tensor(self.multiplier * self.scale)
        )
        bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
        return self.op(
            x,
            weight.view(self.shape),
            bias,
            **self.extra_args
        )
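

# Example usage (a sketch, not part of the original module; the layer sizes below
# are illustrative): wrap an existing nn.Linear so that calls to it are routed
# through the LoKr-merged weight.
#
#   base = nn.Linear(768, 320)
#   lokr = LokrModule("lokr_example", base, multiplier=1.0, lora_dim=4, alpha=1)
#   lokr.apply_to()                 # replaces base.forward with lokr.forward
#   y = base(torch.randn(2, 768))   # shape (2, 320), computed from the merged weight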