# Copied from rut5compressed/util.py of rut5compressed repository.

import logging
import re
from functools import wraps
from re import Pattern
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch as T

from .modules import TTCompressedLinear


def map_module(root: T.nn.Module,
               func: Callable[[T.nn.Module, str], T.nn.Module],
               patt: Optional[str] = None) -> T.nn.Module:
    """Function ``map_module`` applies a function to each leaf of module tree
    which matches to a specified pattern.

    Parameters
    ----------
    root : torch.nn.Module
        Module to modify.
    func : callable
        Function to be applied to every module (or matched to pattern) in
        module tree.
    patt : str, optional
        Pattern to filter modules by path in module tree.

    Returns
    -------
    torch.nn.Module
        Module modified in-place.
    """
    @wraps(func)
    def func_safe(*args, **kwargs):
        node = func(*args, **kwargs)
        if not isinstance(node, T.nn.Module):
            raise ValueError('Mapped result must be of torch.nn.Module type '
                             f'but {type(node)} was given.')
        return node

    return _map_module(root, func_safe, re.compile(patt or r'.*'), '')


def _map_module(root: T.nn.Module,
                func: Callable[[T.nn.Module, str], T.nn.Module], patt: Pattern,
                path: str) -> T.nn.Module:
    for name, child in root.named_children():
        node = _map_module(child, func, patt, f'{path}/{name}')
        if node is not child:
            setattr(root, name, node)
    if patt.match(path or '/'):
        root = func(root, path or '/')
    return root
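
# Usage sketch for ``map_module`` (illustrative; the model below is a toy):
#
#     model = T.nn.Sequential(T.nn.Linear(8, 8), T.nn.ReLU())
#
#     def report(module: T.nn.Module, path: str) -> T.nn.Module:
#         print(path, module.__class__.__name__)
#         return module  # The mapped function must return a module.
#
#     map_module(model, report)           # Visits '/0', '/1' and '/'.
#     map_module(model, report, r'.*/0')  # Visits the first child only.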


def convert_linear(module: T.nn.Linear, ctor, **kwargs) -> T.nn.Module:
    """Function convert_linear takes module and returns linear module with
    approximate matmul. Non-linear modules are returned intact.
    """
    if not isinstance(module, T.nn.Linear):
        return module
    raise NotImplementedError
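
# ``convert_linear`` is still a stub. A possible body, assuming ``ctor`` is a
# compressed-linear class with a ``from_linear`` factory in the spirit of
# ``TTCompressedLinear`` (an assumption, not a decided design):
#
#     return ctor.from_linear(module, **kwargs)
#
# and a possible wiring through ``map_module``:
#
#     map_module(model, lambda m, p: convert_linear(
#         m, TTCompressedLinear, shape=((4, 8), (8, 8)), rank=2))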


def numel(module: T.nn.Module):
    """Function ``numel`` counts the number of elements (parameters and
    buffers) of a module, with corrections for pruned and quantized
    submodules.
    """
    value = sum(x.numel() for x in module.parameters()) + \
            sum(x.numel() for x in module.buffers())

    def account_pruned(module: T.nn.Module, path: str):
        # Masks produced by ``torch.nn.utils.prune`` are registered as
        # buffers named ``<weight>_mask``.
        nonlocal value
        for name, attr in module.named_buffers(recurse=False):
            if not name.endswith('_mask'):
                continue

            weight_name = name[:-5]
            if not hasattr(module, weight_name):
                continue

            # Exclude pruned-away (zeroed) elements of the weight and count
            # the mask elements.
            weight = getattr(module, weight_name)
            value -= weight.numel() - int(attr.sum().item())
            value += attr.numel()
        return module

    def account_quantized(module: T.nn.Module, path: str):
        # Packed weights of quantized linear layers are neither parameters
        # nor buffers, so they are counted separately.
        nonlocal value
        if isinstance(module, T.nn.quantized.Linear):
            value += module.weight().numel()
            if module.bias() is not None:
                value += module.bias().numel()
        return module

    def account_rest(module: T.nn.Module, path: str):
        account_pruned(module, path)
        account_quantized(module, path)
        return module

    map_module(module, account_rest)
    return value
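
# Usage sketch (illustrative): count elements of a layer pruned with the
# standard ``torch.nn.utils.prune`` machinery.
#
#     from torch.nn.utils import prune
#     layer = T.nn.Linear(16, 16)
#     prune.l1_unstructured(layer, 'weight', amount=0.5)
#     numel(layer)  # Element count adjusted for the pruning mask.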


def sizeof(module: T.nn.Module):
    """Function ``sizeof`` estimates the memory footprint of a module in
    bytes (parameters and buffers), with corrections for pruned and quantized
    submodules.
    """
    value = sum(x.numel() * x.element_size() for x in module.parameters()) + \
            sum(x.numel() * x.element_size() for x in module.buffers())

    def account_pruned(module: T.nn.Module, path: str):
        # Masks produced by ``torch.nn.utils.prune`` are registered as
        # buffers named ``<weight>_mask``.
        nonlocal value
        for name, attr in module.named_buffers(recurse=False):
            if not name.endswith('_mask'):
                continue

            weight_name = name[:-5]
            if not hasattr(module, weight_name):
                continue

            # Exclude bytes of pruned-away weights and count bytes of the
            # mask itself.
            weight = getattr(module, weight_name)
            value -= (weight.numel() - int(attr.sum().item())) * \
                weight.element_size()
            value += attr.numel() * attr.element_size()
        return module

    def account_quantized(module: T.nn.Module, path: str):
        # Packed weights of quantized linear layers are neither parameters
        # nor buffers, so they are counted separately.
        nonlocal value
        if isinstance(module, T.nn.quantized.Linear):
            value += module.weight().numel() * module.weight().element_size()
            if (bias := module.bias()) is not None:
                value += bias.numel() * bias.element_size()
        return module

    def account_rest(module: T.nn.Module, path: str):
        account_pruned(module, path)
        account_quantized(module, path)
        return module

    map_module(module, account_rest)
    return value
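
# Usage sketch (illustrative): ``sizeof`` scales with element width while
# ``numel`` does not.
#
#     layer = T.nn.Linear(16, 16)
#     numel(layer)          # 272 elements (256 weights + 16 biases).
#     sizeof(layer)         # 1088 bytes in float32.
#     sizeof(layer.half())  # 544 bytes in float16.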


def flatten_module(module: T.nn.Module, regexp=None) -> Dict[str, T.nn.Module]:
    """Function ``flatten_module`` builds a flat mapping from paths in the
    module tree to the corresponding modules, optionally filtered by pattern
    ``regexp``.
    """
    modules = {}
    map_module(module, lambda x, y: modules.update({y: x}) or x, regexp)
    return modules
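
# Usage sketch (illustrative):
#
#     model = T.nn.Sequential(T.nn.Linear(8, 8), T.nn.ReLU())
#     flatten_module(model)             # Maps '/0', '/1' and '/' to modules.
#     flatten_module(model, r'.*/\d+')  # Children only: '/0' and '/1'.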


def print_flatten(module: T.nn.Module):
    """Function ``print_flatten`` prints a table of all modules in the module
    tree with their paths and layer types.
    """
    paths = []
    path_len = 0
    names = []
    name_len = 0
    indx_len = 0

    def func(module, path):
        nonlocal path_len, name_len, indx_len
        paths.append(path)
        path_len = max(path_len, len(path))
        name = module.__class__.__name__
        names.append(name)
        name_len = max(name_len, len(name))
        indx_len += 1
        return module

    map_module(module, func)

    # The index column must be at least one character wide.
    indx_len = max(int(np.ceil(np.log10(indx_len))), 1)
    fmt = f'{{indx:>{indx_len}s}} {{path:{path_len}s}} {{name:{name_len}s}}'
    print(fmt.format(indx='#', path='Path', name='Layer'))
    print('-' * (indx_len + path_len + name_len + 2))
    for i, (path, name) in enumerate(zip(paths, names)):
        print(fmt.format(indx=str(i), path=path, name=name))
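
# Usage sketch (illustrative): render a small model as a table with columns
# #, Path and Layer.
#
#     print_flatten(T.nn.Sequential(T.nn.Linear(8, 8), T.nn.ReLU()))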


def compress_linear_tt(module: T.nn.Module, path: str,
                       shape: Tuple[Tuple[int, ...], Tuple[int, ...]],
                       rank: int) -> T.nn.Module:
    """Function ``compress_linear_tt`` replaces a linear layer with a
    TT-compressed one of the given TT shape and rank. Modules of other types
    are returned intact.
    """
    if not isinstance(module, T.nn.Linear):
        return module

    # TODO(@not-found): We need proper compression config.
    inp_size = np.prod(shape[0])
    out_size = np.prod(shape[1])
    if inp_size == module.in_features and out_size == module.out_features:
        pass
    elif inp_size == module.out_features and out_size == module.in_features:
        # The shape is given in transposed form: swap input and output parts.
        shape = (shape[1], shape[0])
    else:
        raise ValueError(
            'Input and output features do not match the compression shape: '
            f'{shape[0]} vs {module.in_features} and {shape[1]} vs '
            f'{module.out_features}.')

    logging.info('apply tt compression to layer %s', path)
    return TTCompressedLinear.from_linear(module, shape, rank)
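
# Usage sketch (illustrative; the pattern and the TT shape below are made up,
# and suitable shapes and ranks depend on the layer):
#
#     from functools import partial
#     model = T.nn.Sequential(T.nn.Linear(32, 64))
#     model = map_module(
#         model,
#         partial(compress_linear_tt, shape=((4, 8), (8, 8)), rank=2),
#         r'.*/0')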