###########################################################################
# NLP demo software by HyperbeeAI.                                        #
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected] #
###########################################################################
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected]"
print("imported functions.py")
print(license_statement)
print("")

import torch, sys
import torch.nn as nn
from torch.autograd import Function

# up-down quantization: rounds x (half up) to the nearest multiple of 2**-(xb-1),
# i.e., fake-quantizes x to xb bits while keeping its original scale
class Q_ud(Function):
    @staticmethod
    def forward(_, x, xb):
        factor = 2**(xb-1)
        return x.mul(factor).add(.5).floor().div(factor)

# up quantization: scales x by 2**(8-xb) and rounds (half up) to an integer
class Q_u(Function):
    @staticmethod
    def forward(_, x, xb):
        factor = 2**(8-xb)
        return x.mul(factor).add(.5).floor()

# down quantization: divides x by 2**(xb-1) and rounds (half up) to an integer
class Q_d(Function):
    @staticmethod
    def forward(_, x, xb):
        factor = 2**(xb-1)
        return x.div(factor).add(.5).floor()
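
# Illustrative sketch (hypothetical helper, not called anywhere): numerical
# examples of the three rounding functions above for xb = 4 bits, assuming
# inputs in roughly [-1, 1) for Q_ud/Q_u and integer-scale inputs for Q_d.
def _quantization_examples():
    x = torch.tensor([0.30, -0.51, 0.99])
    a = Q_ud.apply(x, 4)          # factor 8:  [0.250, -0.500, 1.000]
    b = Q_u.apply(x, 4)           # factor 16: [5., -8., 16.]
    c = Q_d.apply(x.mul(64), 4)   # factor 8:  [2., -4., 8.]
    return a, b, c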

# module wrapper around the functions above; 'mode' selects which one is applied:
# 'updown' fake-quantizes to xb bits, 'up' scales by 2**(8-xb), 'down' divides
# by 2**(xb-1) (by 2**(xb-6) when wide=True), all with half-up rounding
class quantization(nn.Module):
    def __init__(self, xb = 8, mode='updown', wide=False):
        super().__init__()
        self.xb   = xb
        self.mode = mode
        self.wide = wide

    def forward(self, x):
        if(self.mode=='updown'):
            return Q_ud.apply(x, self.xb)
        elif(self.mode=='down'):
            if(self.wide):
                return Q_d.apply(x, self.xb - 5)
            else:
                return Q_d.apply(x, self.xb)
        elif(self.mode=='up'):
            return Q_u.apply(x, self.xb)
        else:
            print('wrong quantization mode. exiting')
            sys.exit()

# saturates values to the representable signed range: [-2**(xb-1), 2**(xb-1)-1]
# for xb-bit outputs, or [-2**29, 2**29-1] when wide=True
class clamping_hw(nn.Module):
    def __init__(self, xb = 8, wide=False):
        super().__init__()
        if(wide):
            self.min_val = -2**(30-1)
            self.max_val =  2**(30-1)-1
        else:
            self.min_val = -2**(xb-1)
            self.max_val =  2**(xb-1)-1

    def forward(self, x):
        return x.clamp(min=self.min_val, max=self.max_val)
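
# Minimal usage sketch (assumed pairing; the helper name is illustrative only):
# 'up' quantization produces integer values, and clamping_hw then saturates them
# to the signed xb-bit range, e.g. [-128, 127] for xb = 8.
def _clamping_example():
    q = quantization(xb=8, mode='up')   # factor 2**(8-8)=1, so this just rounds
    c = clamping_hw(xb=8)
    x = torch.tensor([0.7, 200.6, -300.2])
    return c(q(x))                      # -> tensor([   1.,  127., -128.])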

###################################################
### Linear layer functional
def linear_functional(x, weight, bias, _stride, _padding):
    # dummy linear function that has the same argument signature as conv;
    # _stride and _padding are accepted but ignored
    return nn.functional.linear(x, weight, bias)
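
# Hedged end-to-end sketch, run only when this file is executed directly.
# The shapes, seed, and 8-bit setting below are illustrative assumptions,
# not values prescribed elsewhere in this code.
if __name__ == "__main__":
    torch.manual_seed(0)
    x      = torch.randn(2, 16)
    weight = torch.randn(4, 16)
    bias   = torch.zeros(4)

    # _stride and _padding are ignored by linear_functional, so None is fine here
    y = linear_functional(x, weight, bias, None, None)

    q = quantization(xb=8, mode='updown')   # fake-quantize the output to 8 bits
    print(q(y))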