import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class IA3Module(nn.Module):
    """
    Elementwise (Hadamard product) activation rescaling, as used in (IA)^3
    """
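    # (IA)^3 ("Infused Adapter by Inhibiting and Amplifying Inner
    # Activations", Liu et al., 2022) learns per-channel vectors that
    # rescale activations elementwise. Here the vector is parameterized as
    # 1 + multiplier * weight, so the zero-initialized adapter starts out
    # as the identity.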

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        train_on_input=False,
        **kwargs
    ):
        """train_on_input chooses whether the learned vector rescales the
        input of org_module or its output."""
        super().__init__()
        self.lora_name = lora_name
        self.cp = False

        self.shape = org_module.weight.shape
        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            train_dim = in_dim if train_on_input else out_dim
            # Shape (1, C, 1, 1): one scale per channel, broadcast over
            # batch and spatial dimensions.
            self.weight = nn.Parameter(torch.empty(1, train_dim, 1, 1))
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            train_dim = in_dim if train_on_input else out_dim
            # Shape (F,): one scale per feature, broadcast over all
            # leading dimensions.
            self.weight = nn.Parameter(torch.empty(train_dim))

        # Zero init: 1 + weight starts as the identity scaling.
        torch.nn.init.constant_(self.weight, 0)

        self.multiplier = multiplier
        self.org_forward = None
        # Held in a plain list so the wrapped module is not registered as a
        # submodule (its parameters stay out of this adapter's state dict).
        self.org_module = [org_module]
        self.grad_ckpt = False
        self.train_input = train_on_input
        # Persist the flag as a buffer so checkpoints record which side
        # was trained.
        self.register_buffer('on_input', torch.tensor(int(train_on_input)))

    def apply_to(self):
        # Monkey-patch the wrapped module: keep a handle to its original
        # forward, then route all calls through this adapter.
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    @torch.enable_grad()  # keep autograd on even inside torch.no_grad() blocks
    def forward(self, x):
        # Learned rescaling: y = f(x * (1 + m*w)) when training on the input,
        # y = f(x) * (1 + m*w) when training on the output.
        if self.train_input:
            x = x * (1 + self.weight * self.multiplier)
        out = self.org_forward(x)
        dtype = out.dtype
        if not self.train_input:
            out = out * (1 + self.weight * self.multiplier)
        # The multiplication may upcast (e.g. fp16 activations, fp32 weight);
        # cast back to the original activation dtype.
        out = out.to(dtype)
        return out
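

# Minimal usage sketch, not part of the module above: the layer, variable
# names, and "ia3_demo" identifier below are illustrative assumptions. It
# wraps a plain nn.Linear, hooks it with apply_to(), and checks that the
# zero-initialized adapter leaves the output unchanged.
if __name__ == "__main__":
    linear = nn.Linear(16, 32)
    ia3 = IA3Module("ia3_demo", linear, multiplier=1.0, train_on_input=False)
    ia3.apply_to()  # linear.forward now routes through ia3.forward

    x = torch.randn(4, 16)
    y = linear(x)
    # weight is zero-initialized, so 1 + weight == 1 and the hooked layer
    # matches the original exactly.
    assert torch.allclose(y, F.linear(x, linear.weight, linear.bias))
    print(tuple(y.shape))  # (4, 32)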