File size: 2,318 Bytes
4783804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import torch
from safetensors import safe_open

class ExLlamaV2ModuleWrapper:
    """Proxy around one entry of ``model.modules`` that removes a stored direction from hidden states.

    ``wrap`` replaces every module except the first and the last two with a wrapper.
    Calling ``forward`` on a wrapper runs ``wrapped_forward``. That method does three things:
    it projects the input off this position's suppress direction (when
    ``model._suppress_dir`` is loaded), runs the wrapped module's ``forward``, and projects
    the output the same way.

    Attribute reads other than the wrapper's own state are forwarded to the wrapped
    module first, including ``__class__``. As a result, ``isinstance(wrapper, type(module))``
    is true. Keep dunder delegation intact if callers rely on such checks.

    Shared state lives on the model object:
      * ``model._suppress_dir``: ``None`` or a list of direction tensors.
      * ``model._residual``: ``None`` or a list; when it is a list, wrappers append
        CPU copies of their inputs (see ``wrapped_forward``).
    """

    # Wrapper state that must always resolve on the wrapper itself, even if the wrapped
    # module happens to define an attribute with the same name.
    _OWN_ATTRS = frozenset(('model', 'module', 'idx', 'suppress', 'wrapped_forward'))

    @classmethod
    def wrap(cls, model, load=True):
        """Wrap the inner modules of ``model`` in place and optionally load suppress directions.

        Every entry of ``model.modules`` except index 0 and the last two is replaced by a
        wrapper. The skipped entries are presumably the embedding, final norm and head;
        TODO confirm against the model layout.

        Args:
            model: object exposing ``modules`` (a mutable list) and ``config.model_dir``.
            load: when true and ``suppress_dir.safetensors`` exists in
                ``model.config.model_dir``, load its tensors into ``model._suppress_dir``.
        """
        # Index assignment never changes the list length, so this bound is fixed.
        tail_start = len(model.modules) - 2
        for idx, module in enumerate(model.modules):
            if idx == 0 or idx >= tail_start:
                continue
            model.modules[idx] = ExLlamaV2ModuleWrapper(model, module, idx)

        if not load:
            return

        suppress_dir_file = os.path.join(model.config.model_dir, 'suppress_dir.safetensors')
        if not os.path.exists(suppress_dir_file):
            # The modules are already wrapped at this point; suppress() is a no-op while
            # model._suppress_dir is None, so they pass hidden states through unchanged.
            print(f'No suppress direction file, not wrapping. Tried to load: "{suppress_dir_file}"')
            return

        print(f'Loading suppress direction file "{suppress_dir_file}"')
        with safe_open(suppress_dir_file, framework='pt', device='cpu') as f:
            # Expects keys _suppress_dir_0 .. _suppress_dir_{n-1}, where n is the key count.
            directions = [f.get_tensor(f'_suppress_dir_{layer}') for layer in range(len(f.keys()))]
        # Publish only after every tensor has loaded, so a failure midway cannot leave a
        # partially filled list on the model.
        model._suppress_dir = directions

    def __init__(self, model, module, idx):
        """Wrap ``module``, which sits at ``model.modules[idx]``.

        Creates the shared ``model._suppress_dir`` and ``model._residual`` slots (as
        ``None``) if they do not exist yet; existing values are left untouched.
        """
        if not hasattr(model, '_suppress_dir'):
            model._suppress_dir = None
        if not hasattr(model, '_residual'):
            model._residual = None
        self.model = model
        self.module = module
        self.idx = idx

    def __getattribute__(self, name):
        """Resolve ``forward`` to ``wrapped_forward`` and delegate other names to the module.

        The wrapper's own state (``_OWN_ATTRS``) is always read from the wrapper. Any other
        name is looked up on the wrapped module first. If the module lacks it, lookup falls
        back to the wrapper.
        """
        if name == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')
        # type() is not affected by the __class__ delegation below, so this reads the real
        # wrapper class without recursing into this method.
        if name in type(self)._OWN_ATTRS:
            return object.__getattribute__(self, name)

        try:
            return getattr(object.__getattribute__(self, 'module'), name)
        except AttributeError:
            pass
        return object.__getattribute__(self, name)

    def suppress(self, x):
        """Return ``x`` with its component along this position's suppress direction removed.

        Returns ``x`` unchanged when ``model._suppress_dir`` is None. The last dimension of
        ``x`` must match the direction length.
        """
        directions = self.model._suppress_dir
        if directions is None:
            return x
        # NOTE(review): for idx == 1 this reads directions[-1] (the last entry); confirm the
        # intended offset against how suppress_dir.safetensors is produced.
        # Match x's device and dtype (matmul rejects mixed dtypes). No clone is needed:
        # nothing below mutates r, so the stored direction is never modified.
        r = directions[self.idx - 2].to(device=x.device, dtype=x.dtype).view(-1, 1)
        # x - (x @ r) r^T: an orthogonal projection only if r is unit-norm, which is
        # presumably guaranteed when the file is written. TODO verify.
        proj = torch.matmul(x, r) * r.transpose(0, 1)
        return x - proj

    def wrapped_forward(self, *args, **kwargs):
        """Replacement ``forward``: suppress the input, run the module, suppress the output.

        Assumes the hidden state is ``args[0]``, presumably shaped (batch, seq, dim); TODO
        confirm. When ``model._residual`` is a list, holds fewer than ``idx`` entries, and
        ``args[0].shape[1] == 1``, a CPU copy of the unsuppressed input is appended to it.
        """
        hidden = args[0]
        residual = self.model._residual
        if residual is not None and len(residual) < self.idx and hidden.shape[1] == 1:
            residual.append(hidden.clone().to('cpu'))
        out = self.module.forward(self.suppress(hidden), *args[1:], **kwargs)
        return self.suppress(out)