import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class IA3Module(nn.Module):
    """
    Hadamard product implementation for low-rank adaptation ((IA)^3):
    a learned per-channel scale applied elementwise to the wrapped
    module's input or output activations.
    """

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        train_on_input=False,
        **kwargs
    ):
""" if alpha == 0 or None, alpha is rank (no scaling). """
        super().__init__()
        self.lora_name = lora_name
        self.cp = False
        self.shape = org_module.weight.shape

        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            train_dim = in_dim if train_on_input else out_dim
            # One scale per channel, broadcast over batch and spatial dims
            self.weight = nn.Parameter(torch.empty(1, train_dim, 1, 1))
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            train_dim = in_dim if train_on_input else out_dim
            self.weight = nn.Parameter(torch.empty(train_dim))

        # The init method needs more experimentation; zeros make the
        # module an exact identity at the start of training.
        torch.nn.init.constant_(self.weight, 0)

        self.multiplier = multiplier
        self.org_forward = None
        # Held in a list so org_module is not registered as a submodule
        self.org_module = [org_module]
        self.grad_ckpt = False
        self.train_input = train_on_input
        self.register_buffer('on_input', torch.tensor(int(train_on_input)))

    def apply_to(self):
        # Monkey-patch: reroute the original module's forward through this one
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    @torch.enable_grad()
    def forward(self, x):
        if self.train_input:
            # Scale the input activations: x * (1 + l * multiplier)
            x = x * (1 + self.weight * self.multiplier)
        out = self.org_forward(x)
        dtype = out.dtype
        if not self.train_input:
            # Otherwise scale the output activations instead
            out = out * (1 + self.weight * self.multiplier)
        # Cast back in case the fp32 scale promoted a half-precision output
        out = out.to(dtype)
        return out
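

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original file):
# wraps a plain nn.Linear with IA3Module and checks that, with the scale
# initialized to zero, the patched forward reproduces the original output.
# The names `base` and `ia3` are hypothetical.
if __name__ == '__main__':
    base = nn.Linear(16, 8)
    x = torch.randn(4, 16)
    ref = base(x)

    ia3 = IA3Module('ia3_test', base, multiplier=1.0, train_on_input=False)
    ia3.apply_to()  # base.forward now routes through IA3Module.forward

    out = base(x)
    # weight == 0, so the (1 + weight * multiplier) factor is exactly 1
    assert torch.allclose(ref, out)
    print('IA3 wrap OK:', tuple(out.shape))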