import timm
from timm.models._factory import load_checkpoint
import torch
import os
from torch import nn
from torch.jit import Final
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from utils.dl.common.model import get_model_device, set_module
import torch.nn.functional as F
from utils.common.log import logger


# class SoftmaxIgnoringZero(nn.Module):
#     def __init__(self):
#         super(SoftmaxIgnoringZero, self).__init__()
#
#     def forward(self, x: torch.Tensor):
#         # non_zero_x_indexes = x.nonzero(as_tuple=True)[0]
#         # non_zero_x = x[non_zero_x_indexes]
#         # non_zero_x_softmax = F.softmax(non_zero_x, self.dim, _stacklevel=5)
#         # res = torch.zeros_like(x)
#         # original: e^i / \sum_i e^i
#         # ignoring zero: e^i
#         # print(x)
#         non_zero_mask = x != 0
#         if non_zero_mask.sum() == x.numel():
#             return F.softmax(x, -1)
#         t = non_zero_mask.sum(-1)
#         assert t.view(-1).unique().size(0) == 1, f'{t.view(-1).unique()}, {x.size()}'  # all vectors in the softmaxed dim have the same number of zeros
#         # assert t.view(-1).unique().size(0) <= 2, f'{t.view(-1).unique()}, {x.size()}'  # all vectors in the softmaxed dim have the same number of zeros, or none at all
#         non_zero_x = torch.masked_select(x, non_zero_mask)
#         non_zero_x = non_zero_x.view(*(list(x.size())[0: -1] + [t.view(-1)[0].item()]))
#         # print(non_zero_x)
#         non_zero_x_softmax = F.softmax(non_zero_x, -1)
#         a = x.nonzero(as_tuple=True)[-1]
#         a = a.view(*non_zero_x_softmax.size())
#         x = x.scatter(x.dim() - 1, a, non_zero_x_softmax)
#         return x


class SoftmaxIgnoringZero(nn.Module):
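    """A drop-in replacement for ``nn.Softmax(dim=-1)``.

    The softmax output is multiplied by ``f(x)``, which currently returns 1.0,
    so this module behaves exactly like a plain softmax. ``f`` appears to be a
    hook for masking pruned (exactly-zero) logits, in the spirit of the
    commented-out variant above and of the alternative ``x / (x + 1e-8)``.
    """
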
    def __init__(self):
        super(SoftmaxIgnoringZero, self).__init__()

    def f(self, x):
        # return x / (x + 1e-8)
        return 1.

    def forward(self, x: torch.Tensor):
        res = F.softmax(x, -1)
        return res * self.f(x)


class PrunableAttention(nn.Module):
    """
    Multi-head self-attention whose qkv and proj projections are plain
    nn.Linear layers so that their channels can be pruned; forward() re-derives
    the (possibly pruned) head width from the actual qkv output size.

    Adapted from https://github.com/lucidrains/vit-pytorch
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qkv_bias = False):
        super().__init__()
        self.inner_dim = inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.num_heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.qkv = nn.Linear(dim, inner_dim * 3, bias = qkv_bias)

        # self.proj = nn.Sequential(
        #     nn.Linear(inner_dim, dim),
        #     nn.Dropout(dropout)
        # ) if project_out else nn.Identity()
        self.proj = nn.Linear(inner_dim, dim) if project_out else nn.Identity()
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x):
        # qkv = self.qkv(x).chunk(3, dim = -1)
        raw_qkv = self.qkv(x)
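        # Recover the (possibly pruned) q/k width from the actual qkv output:
        # v keeps self.proj.in_features channels and q/k share the remainder
        # equally. NOTE: this assumes self.proj is an nn.Linear (project_out=True).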
        self.inner_dim = (raw_qkv.size(-1) - self.proj.in_features) // 2
        qkv = raw_qkv[:, :, 0: self.inner_dim], raw_qkv[:, :, self.inner_dim: self.inner_dim * 2], raw_qkv[:, :, self.inner_dim * 2:]
        # print('v', qkv[0].size(), qkv[0].sum((0, 1))[0: 10], qkv[0].sum((0, 1)).nonzero(as_tuple=True)[0].size())
        # raw_v = qkv[2]
        # print('after_fbs_q, after_fbs_k', qkv[0].sum((0, 1))[0: 10], qkv[0].sum((0, 1)).nonzero(as_tuple=True)[0].size(),
        #       qkv[1].sum((0, 1))[0: 10], qkv[1].sum((0, 1)).nonzero(as_tuple=True)[0].size(),)
        # print('after_fbs_v', raw_v.size(), raw_v.sum((0, 1))[0: 10], raw_v.sum((0, 1)).nonzero(as_tuple=True)[0].size())

        # print('q, before rearrange', qkv[0].size())
        q, k, v = qkv
        # print('raw qkv size', q.size(), k.size(), v.size())
        # exit()

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.num_heads), qkv)
        # print('raw qkv size', q.size(), k.size(), v.size())
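        # scaled dot-product attention: softmax(q @ k^T / sqrt(dim_head)) @ v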
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # print('q, k, dots, after rearrange', q.size(), k.transpose(-1, -2).size(), dots.size())

        attn = self.attend(dots)
        # attn = dots
        attn = self.dropout(attn)
        # print(attn)
        # print('attn', attn.size(), attn.sum((0, 1))[0: 10], attn.sum((0, 1)).nonzero(as_tuple=True)[0].size())
        # print('v2', v.size())

        out = torch.matmul(attn, v)
        # print('out1', out.size())

        # NOTE: just for trial debug
        # out = v
        # print('out before rearrange', out.size())
        # print(v.size(), v)
        # exit()

        out = rearrange(out, 'b h n d -> b n (h d)')
        # print('out', out.size(), out.sum((0, 1))[0: 10], out.sum((0, 1)).nonzero(as_tuple=True)[0].size())
        # exit()

        res = self.proj_dropout(self.proj(out))
        # res = self.proj_dropout(
        #     F.linear(self.proj.weight.T, out.T, self.proj.bias)
        # )
        # print(self.proj, self.proj_dropout)
        # print('res', res.size(), res.sum((0, 1))[0: 10], res.sum((0, 1)).nonzero(as_tuple=True)[0].size())
        return res


def make_attention_prunable(vit):
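    """Swap every ``block.attn`` of a timm VisionTransformer for an equivalent
    ``PrunableAttention`` (in place, via ``set_module``), copying the
    pretrained qkv/proj weights. Assumes each ``block.attn`` is timm's
    ``Attention`` module with ``num_heads``, ``head_dim``, ``qkv`` and ``proj``
    attributes.
    """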
    for block in vit.blocks:
        attn = block.attn
        assert attn.attn_drop.p == attn.proj_drop.p

        prunable_attn = PrunableAttention(
            dim=attn.head_dim * attn.num_heads,
            heads=attn.num_heads,
            dim_head=attn.head_dim,
            dropout=attn.attn_drop.p,
            qkv_bias=attn.qkv.bias is not None
        )
        # copy the pretrained weights; no_grad is required because copy_() is an
        # in-place op on parameters that require grad
        with torch.no_grad():
            prunable_attn.qkv.weight.copy_(attn.qkv.weight)
            if attn.qkv.bias is not None:
                prunable_attn.qkv.bias.copy_(attn.qkv.bias)
            prunable_attn.proj.weight.copy_(attn.proj.weight)
            prunable_attn.proj.bias.copy_(attn.proj.bias)

        set_module(block, 'attn', prunable_attn)


def vit_l_16(pretrained=True, num_classes=None) -> nn.Module:
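    """ViT-L/16 (224) from timm with prunable attention.

    If ``pretrained`` is True, weights are loaded from a local copy of the
    augreg_in21k_ft_in1k checkpoint; when ``num_classes`` is given, the
    pretrained classification head is dropped so timm creates a fresh one.
    A forward pass before and after the attention swap verifies that the
    outputs match.
    """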
    # https://huggingface.co/timm/vit_large_patch16_224.augreg_in21k_ft_in1k
    res = timm.create_model('vit_large_patch16_224.augreg_in21k_ft_in1k',
                            num_classes=num_classes)

    if pretrained:
        checkpoint_path = os.path.join(os.path.dirname(__file__),
                                       'weights/vit_large_patch16_224.augreg_in21k_ft_in1k.bin')

        def filter_fn(state_dict, _):
            if num_classes is None:  # use the fine-tuned in1k fc head
                return state_dict
            else:  # a new linear head is created by timm, so drop the pretrained one
                del state_dict['head.weight']
                del state_dict['head.bias']
                return state_dict

        load_checkpoint(res, checkpoint_path, strict=False, filter_fn=filter_fn)

    # sanity check: swapping in PrunableAttention must not change the output
    res.eval()
    input_sample = torch.rand(2, 3, 224, 224)
    o1 = res(input_sample)

    make_attention_prunable(res)

    res.eval()
    o2 = res(input_sample)
    assert ((o1 - o2) ** 2).sum() < 1e-5

    return res


from timm.models.vision_transformer import VisionTransformer


def vit_b_16(pretrained=True, num_classes=None) -> VisionTransformer:
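    """ViT-B/16 (224) from timm with prunable attention; same loading and
    sanity-check logic as ``vit_l_16``.
    """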
    # https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k
    res = timm.create_model('vit_base_patch16_224.augreg_in21k_ft_in1k',
                            num_classes=num_classes)

    if pretrained:
        checkpoint_path = os.path.join(os.path.dirname(__file__),
                                       'weights/vit_base_patch16_224.augreg_in21k_ft_in1k.bin')

        def filter_fn(state_dict, _):
            if num_classes is None:  # use the fine-tuned in1k fc head
                return state_dict
            else:  # a new linear head is created by timm, so drop the pretrained one
                del state_dict['head.weight']
                del state_dict['head.bias']
                return state_dict

        load_checkpoint(res, checkpoint_path, strict=False, filter_fn=filter_fn)

    # sanity check: swapping in PrunableAttention must not change the output
    res.eval()
    input_sample = torch.rand(2, 3, 224, 224)
    o1 = res(input_sample)

    logger.info('make attention prunable')
    make_attention_prunable(res)
    # logger.info('make softmax prunable')
    # make_softmax_prunable(res)

    res.eval()
    o2 = res(input_sample)
    # print(((o1 - o2) ** 2).sum())
    assert ((o1 - o2) ** 2).sum() < 1e-5

    return res


def make_softmax_prunable(model):
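    """Replace every ``nn.Softmax`` in ``model`` with ``SoftmaxIgnoringZero``
    and verify on a random input that the output is unchanged (assumes a model
    that takes 224x224 images).
    """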
    model.eval()
    input_sample = torch.rand(2, 3, 224, 224).to(get_model_device(model))
    o1 = model(input_sample)

    for name, module in model.named_modules():
        if isinstance(module, nn.Softmax):
            set_module(model, name, SoftmaxIgnoringZero())
            logger.info(f'make softmax {name} prunable')

    model.eval()
    o2 = model(input_sample)
    assert ((o1 - o2) ** 2).sum() < 1e-5

    return model


if __name__ == '__main__':
    model = vit_l_16()
    model(torch.rand((1, 3, 224, 224)))
    # from utils.dl.common.data_loader import ImageNetDataLoader
    # _, test_loader = ImageNetDataLoader('/data/zql/datasets/imagenet2012/train', '/data/zql/datasets/imagenet2012/val', 512, 8)
    # import torch
    # import tqdm
    # import torch.nn.functional as F

    # def get_accuracy(model, dataloader=test_loader, device='cuda'):
    #     acc = 0
    #     sample_num = 0
    #     model.eval()
    #     model = model.to(device)
    #     with torch.no_grad():
    #         pbar = tqdm.tqdm(enumerate(dataloader), total=len(dataloader), dynamic_ncols=True, leave=False)
    #         for batch_index, (x, y) in pbar:
    #             x, y = x.to(device), y.to(device)
    #             output = model(x)
    #             pred = F.softmax(output, dim=1).argmax(dim=1)
    #             correct = torch.eq(pred, y).sum().item()
    #             acc += correct
    #             sample_num += len(y)
    #             pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
    #                                  f'cur_batch_acc: {(correct / len(y)):.4f}')
    #     acc /= sample_num
    #     return acc

    # model = model.cuda()
    # print(f'vit_l_16 im1k acc: {get_accuracy(model, test_loader, "cuda")}')

    # softmax = SoftmaxIgnoringZero()
    # x = torch.tensor([[[1, 0, 3], [2, 2, 0]]] * 2).float()
    # print(softmax(x))

    # model = vit_b_16(True)
    # print(get_accuracy(model))
    # for name, module in model.named_modules():
    #     if isinstance(module, nn.Softmax):
    #         set_module(model, name, SoftmaxIgnoringZero())
    #         print(f'{name}')
    # # print(model)
    # print(get_accuracy(model))

    # softmax = SoftmaxIgnoringZero()
    # linear = nn.Linear(20, 10)
    # net = nn.Sequential(linear, softmax)
    # optimizer = torch.optim.SGD(net.parameters(), lr=10, momentum=0.9)
    # x = torch.rand((64, 20))
    # y_g = torch.rand((64, 10))
    # for _ in range(100):
    #     y = net(x)
    #     # print(y)
    #     loss = F.mse_loss(y, y_g)
    #     optimizer.zero_grad()
    #     loss.backward()
    #     # print(linear.weight.grad)
    #     optimizer.step()
    #     print(loss)
    softmax = SoftmaxIgnoringZero()
    x = torch.tensor([
        [1, 0, 2],
        [4, 0, 9],
        [0, 0, 0],
        [1, 1, 1]
    ]).float()
    print(softmax(x))

    x = torch.tensor([
        [1, 2],
        [4, 9],
    ]).float()
    print(softmax(x))