memyprokotow's picture
Upload TTCompressedBartForConditionGeneration
4cacee8
raw
history blame
6.39 kB
# Copied from rut5compressed/util.py of rut5compressed repository.
import logging
import re
from functools import wraps
from re import Pattern
from typing import Callable, Dict, Optional, Tuple
import numpy as np
import torch as T
from .modules import TTCompressedLinear
def map_module(root: T.nn.Module,
               func: Callable[[T.nn.Module, str], T.nn.Module],
               patt: Optional[str] = None) -> T.nn.Module:
    """Function ``map_module`` applies a function to every node of a module
    tree whose path matches a specified pattern.

    Paths are built by joining child names with ``/`` (e.g. ``/encoder/0``);
    the root module itself is addressed as ``/``. The pattern is applied with
    ``re.match``, i.e. it is anchored at the beginning of the path only.

    Parameters
    ----------
    root : torch.nn.Module
        Module to modify.
    func : callable
        Function to be applied to every module (or matched to pattern) in
        module tree. It receives the module and its path and must return a
        module (either the same one or a replacement).
    patt : str, optional
        Pattern to filter modules by path in module tree. If it is omitted or
        empty then every module is matched.

    Returns
    -------
    torch.nn.Module
        Module modified in-place.

    Raises
    ------
    ValueError
        If ``func`` returns an object which is not a ``torch.nn.Module``.
    """
    @wraps(func)
    def func_safe(*args, **kwargs):
        # Validate the result early: a non-module value would otherwise be
        # assigned into the parent module as a child.
        node = func(*args, **kwargs)
        if not isinstance(node, T.nn.Module):
            raise ValueError('Mapped result must be torch.nn.Module type '
                             f'but given {type(node)}.')
        return node

    return _map_module(root, func_safe, re.compile(patt or r'.*'), '')
def _map_module(root: T.nn.Module,
                func: Callable[[T.nn.Module, str], T.nn.Module], patt: Pattern,
                path: str) -> T.nn.Module:
    """Recursively apply ``func`` to the module tree rooted at ``root``.

    Children are processed before their parent. Whenever the recursion
    returns a different object for a child, that child is replaced in its
    parent with ``setattr``.

    Parameters
    ----------
    root : torch.nn.Module
        Root of the (sub)tree to process.
    func : callable
        Function which takes a module and its path and returns a module.
    patt : re.Pattern
        Compiled pattern; ``func`` is called only for modules whose path
        matches it (``Pattern.match``, anchored at the start of the path).
    path : str
        Path of ``root`` in the whole tree; an empty string denotes the
        top-level module which is reported to ``func`` as ``/``.

    Returns
    -------
    torch.nn.Module
        Result of ``func`` for ``root`` if its path matches, ``root`` itself
        otherwise.
    """
    for name, child in root.named_children():
        node = _map_module(child, func, patt, f'{path}/{name}')
        # Replacement is a question of object identity, not of equality: a
        # module type overriding ``__eq__`` must not suppress the update.
        if node is not child:
            setattr(root, name, node)
    full_path = path or '/'
    if patt.match(full_path):
        root = func(root, full_path)
    return root
def convert_linear(module: T.nn.Linear, ctor, **kwargs) -> T.nn.Module:
    """Placeholder for conversion of a linear layer to a layer with
    approximate matmul.

    Modules which are not ``torch.nn.Linear`` are passed through untouched.
    Conversion itself is not implemented: ``NotImplementedError`` is raised
    for any linear layer, and ``ctor`` and ``kwargs`` are currently unused.
    """
    if isinstance(module, T.nn.Linear):
        raise NotImplementedError
    return module
def numel(module: T.nn.Module) -> int:
    """Count elements stored by a module with corrections for pruned and
    quantized layers.

    The base value is the total number of elements of all parameters and
    buffers. Then every module of the tree is inspected once:

    * for each plain tensor attribute named ``<name>_mask`` whose module also
      has an attribute ``<name>``, the elements of ``<name>`` which are not
      selected by the mask (``numel - mask.sum()``) are subtracted and the
      elements of the mask itself are added (the mask is presumably binary);
    * for each ``torch.nn.quantized.Linear`` the elements of its weight and
      (if present) bias are added.

    Parameters
    ----------
    module : torch.nn.Module
        Root of module tree to measure.

    Returns
    -------
    int
        Number of elements as a plain Python integer.
    """
    value = sum(x.numel() for x in module.parameters()) + \
        sum(x.numel() for x in module.buffers())

    # ``modules()`` yields each distinct module exactly once, so a submodule
    # shared by several parents is corrected once -- in agreement with
    # ``parameters()`` and ``buffers()`` above which deduplicate as well.
    for node in module.modules():
        # Account pruned tensors.
        for name, attr in vars(node).items():
            if not name.endswith('_mask') or not isinstance(attr, T.Tensor):
                continue
            weight_name = name[:-5]  # Strip the '_mask' suffix.
            if not hasattr(node, weight_name):
                continue
            weight = getattr(node, weight_name)
            # ``attr.sum()`` is a 0-d tensor; convert it so that the counter
            # stays a Python int instead of silently turning into a tensor.
            value -= weight.numel() - int(attr.sum())
            value += attr.numel()

        # Account quantized linear layers: their weight and bias are
        # obtained through method calls.
        if isinstance(node, T.nn.quantized.Linear):
            value += node.weight().numel()
            if (bias := node.bias()) is not None:
                value += bias.numel()

    return value
def sizeof(module: T.nn.Module) -> int:
    """Estimate size in bytes of a module with corrections for pruned and
    quantized layers.

    The base value is the total of ``numel() * element_size()`` over all
    parameters and buffers. Then every module of the tree is inspected once:

    * for each plain tensor attribute named ``<name>_mask`` whose module also
      has an attribute ``<name>``, the bytes of elements of ``<name>`` which
      are not selected by the mask (``numel - mask.sum()``) are subtracted
      and the bytes of the mask itself are added (the mask is presumably
      binary);
    * for each ``torch.nn.quantized.Linear`` the bytes of its weight and (if
      present) bias are added.

    Parameters
    ----------
    module : torch.nn.Module
        Root of module tree to measure.

    Returns
    -------
    int
        Size in bytes as a plain Python integer.
    """
    value = sum(x.numel() * x.element_size() for x in module.parameters()) + \
        sum(x.numel() * x.element_size() for x in module.buffers())

    # ``modules()`` yields each distinct module exactly once, so a submodule
    # shared by several parents is corrected once -- in agreement with
    # ``parameters()`` and ``buffers()`` above which deduplicate as well.
    for node in module.modules():
        # Account pruned tensors.
        for name, attr in vars(node).items():
            if not name.endswith('_mask') or not isinstance(attr, T.Tensor):
                continue
            weight_name = name[:-5]  # Strip the '_mask' suffix.
            if not hasattr(node, weight_name):
                continue
            weight = getattr(node, weight_name)
            # ``attr.sum()`` is a 0-d tensor; convert it so that the counter
            # stays a Python int instead of silently turning into a tensor.
            pruned = weight.numel() - int(attr.sum())
            value -= pruned * weight.element_size()
            value += attr.numel() * attr.element_size()

        # Account quantized linear layers: their weight and bias are
        # obtained through method calls.
        if isinstance(node, T.nn.quantized.Linear):
            weight = node.weight()
            value += weight.numel() * weight.element_size()
            if (bias := node.bias()) is not None:
                value += bias.numel() * bias.element_size()

    return value
def flatten_module(module: T.nn.Module, regexp=None) -> Dict[str, T.nn.Module]:
    """Collect modules of a module tree into a flat dictionary.

    Parameters
    ----------
    module : torch.nn.Module
        Root of module tree to traverse.
    regexp : str, optional
        Pattern forwarded to ``map_module`` to filter modules by path.

    Returns
    -------
    dict
        Mapping from a module path to the module itself, in the order in
        which ``map_module`` visits the modules.
    """
    collected = {}

    def collect(node, path):
        # Only record the module; it is handed back unchanged so that the
        # tree is not modified.
        collected[path] = node
        return node

    map_module(module, collect, regexp)
    return collected
def print_flatten(module: T.nn.Module):
    """Print a table of all modules in a module tree to standard output.

    Each row consists of a sequential index, the path of a module and the
    name of its class. Rows follow the order in which ``map_module`` visits
    the modules.

    Parameters
    ----------
    module : torch.nn.Module
        Root of module tree to print.
    """
    rows = []

    def collect(node, path):
        rows.append((path, node.__class__.__name__))
        return node

    map_module(module, collect)

    # Every column is at least as wide as its header label, otherwise the
    # header and the horizontal rule are misaligned for small trees (e.g. a
    # single module gave a zero-width index column). The index column must
    # fit the largest index which is printed.
    indx_len = max(len('#'), len(str(max(len(rows) - 1, 0))))
    path_len = max([len('Path')] + [len(path) for path, _ in rows])
    name_len = max([len('Layer')] + [len(name) for _, name in rows])

    fmt = f'{{indx:>{indx_len}s}} {{path:{path_len}s}} {{name:{name_len}s}}'
    print(fmt.format(indx='#', path='Path', name='Layer'))
    # Two extra characters stand for the column separators.
    print('-' * (indx_len + path_len + name_len + 2))
    for i, (path, name) in enumerate(rows):
        print(fmt.format(indx=str(i), path=path, name=name))
def compress_linear_tt(module: T.nn.Module, path: str,
                       shape: Tuple[Tuple[int], Tuple[int]],
                       rank: int) -> T.nn.Module:
    """Replace a linear layer with its ``TTCompressedLinear`` counterpart.

    Parameters
    ----------
    module : torch.nn.Module
        Module to convert. Anything but ``torch.nn.Linear`` is returned
        as is.
    path : str
        Path of the module in a module tree (used for logging only).
    shape : tuple
        Pair of factorizations whose products must be equal to the numbers
        of input and output features of the layer. The pair is swapped if it
        matches the layer only in reversed order.
    rank : int
        Rank forwarded to ``TTCompressedLinear.from_linear``.

    Returns
    -------
    torch.nn.Module
        Result of ``TTCompressedLinear.from_linear`` for a linear layer, or
        the original module otherwise.

    Raises
    ------
    ValueError
        If ``shape`` matches the layer features in neither order.
    """
    if not isinstance(module, T.nn.Linear):
        return module

    # TODO(@not-found): We need proper compression config.
    inp_size, out_size = np.prod(shape[0]), np.prod(shape[1])
    direct = (inp_size == module.in_features
              and out_size == module.out_features)
    flipped = (inp_size == module.out_features
               and out_size == module.in_features)

    if not direct and not flipped:
        raise ValueError(
            'Input and output features does not match to compression shape: '
            f'{shape[0]} vs {module.in_features} and {shape[1]} vs '
            f'{module.out_features}.')
    if not direct:
        # The shape was given as (output, input): put it in layer order.
        shape = (shape[1], shape[0])

    logging.info('apply tt compression to layer %s', path)
    return TTCompressedLinear.from_linear(module, shape, rank)  # noqa: F821