import torch
import torch.nn as nn


class IA3Module(nn.Module):
    """
    Hadamard product Implementaion for Low Rank Adaptation
    """

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        train_on_input=False,
        **kwargs
    ):
        """ if alpha == 0 or None, alpha is rank (no scaling). """
        super().__init__()
        self.lora_name = lora_name
        self.cp = False  # unused here; presumably kept for parity with other adapter modules

        self.shape = org_module.weight.shape
        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            train_dim = in_dim if train_on_input else out_dim
            # one scale per channel, broadcast over batch and spatial dims
            self.weight = nn.Parameter(torch.empty(1, train_dim, 1, 1))
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            train_dim = in_dim if train_on_input else out_dim
            # one scale per feature, broadcast over leading dims
            self.weight = nn.Parameter(torch.empty(train_dim))

        # The init scheme still needs experimentation; zeros make the
        # adapter start as an identity (1 + 0 = 1).
        torch.nn.init.constant_(self.weight, 0)

        self.multiplier = multiplier
        self.org_forward = None
        self.org_module = [org_module]  # wrapped in a list so it is not registered as a submodule
        self.grad_ckpt = False
        self.train_input = train_on_input
        self.register_buffer('on_input', torch.tensor(int(train_on_input)))

    def apply_to(self):
        """Hijack the wrapped module's forward so calls route through this adapter."""
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    @torch.enable_grad()  # keep gradients flowing to the adapter even if the caller runs under no_grad
    def forward(self, x):
        if self.train_input:
            x = x * (1 + self.weight * self.multiplier)
        out = self.org_forward(x)
        dtype = out.dtype
        if not self.train_input:
            out = out * (1 + self.weight * self.multiplier)
            out = out.to(dtype)  # cast back in case the scale promoted the dtype
        return out
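
# --- Usage sketch (illustrative, not part of the original module) ---
# Shows how an IA3Module can be attached to an existing layer; the layer
# sizes and the name 'ia3_demo' are made up for the example.
if __name__ == '__main__':
    base = nn.Linear(16, 32)
    ia3 = IA3Module('ia3_demo', base, multiplier=1.0, train_on_input=False)
    ia3.apply_to()  # base.forward now routes through ia3.forward

    x = torch.randn(4, 16)
    y = base(x)  # identical to the unadapted output while self.weight is all zeros
    assert y.shape == (4, 32)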