|
'''
Copyright (C) 2021 Sovrasov V. - All Rights Reserved
 * You may use, distribute and modify this code under the
 * terms of the MIT license.
 * You should have received a copy of the MIT license with
 * this file. If not visit https://opensource.org/licenses/MIT
'''
|
|
import numpy as np
import torch.nn as nn
|
|
def empty_flops_counter_hook(module, input, output):
    module.__flops__ += 0
|
|
def upsample_flops_counter_hook(module, input, output):
    output_size = output[0]
    batch_size = output_size.shape[0]
    output_elements_count = batch_size
    for val in output_size.shape[1:]:
        output_elements_count *= val
    module.__flops__ += int(output_elements_count)
|
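# Worked example (illustrative numbers, not from the library docs): for an output of
# shape [1, 3, 224, 224] the hook adds 3 * 224 * 224 = 150,528 FLOPs; note that
# output[0] is the first sample, so the batch dimension is not part of the product.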
|
def relu_flops_counter_hook(module, input, output):
    active_elements_count = output.numel()
    module.__flops__ += int(active_elements_count)
|
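# Worked example (illustrative numbers): any of the activations mapped to this hook
# applied to a tensor of shape [8, 64, 56, 56] is counted as one FLOP per element,
# i.e. 8 * 64 * 56 * 56 = 1,605,632 FLOPs.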
|
def linear_flops_counter_hook(module, input, output):
    input = input[0]
    # one multiply-add per weight entry for every row of the flattened input,
    # plus one addition per output feature when a bias is present
    output_last_dim = output.shape[-1]
    bias_flops = output_last_dim if module.bias is not None else 0
    module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops)
|
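# Worked example (illustrative numbers): nn.Linear(512, 1000) on an input of shape
# [8, 512] is counted as 8 * 512 * 1000 multiply-adds plus 1000 bias additions,
# i.e. 4,097,000 FLOPs.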
|
def pool_flops_counter_hook(module, input, output):
    input = input[0]
    module.__flops__ += int(np.prod(input.shape))
|
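# Note: pooling is counted as one FLOP per *input* element, regardless of kernel size
# and stride, so e.g. a 2x2 max pooling over [8, 64, 56, 56] also contributes
# 8 * 64 * 56 * 56 = 1,605,632 FLOPs.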
|
def bn_flops_counter_hook(module, input, output):
    input = input[0]

    # one FLOP per input element for the normalization,
    # doubled when a learnable affine transform is applied
    batch_flops = np.prod(input.shape)
    if module.affine:
        batch_flops *= 2
    module.__flops__ += int(batch_flops)
|
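# Worked example (illustrative numbers): nn.BatchNorm2d(64) with affine=True applied
# to [8, 64, 56, 56] is counted as 2 * 8 * 64 * 56 * 56 = 3,211,264 FLOPs.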
|
def conv_flops_counter_hook(conv_module, input, output):
    # the hook may receive a tuple of inputs; only the first tensor is needed
    input = input[0]

    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    # multiply-adds needed to compute one output position across all output channels
    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(np.prod(kernel_dims)) * \
        in_channels * filters_per_channel

    active_elements_count = batch_size * int(np.prod(output_dims))

    overall_conv_flops = conv_per_position_flops * active_elements_count

    bias_flops = 0
    if conv_module.bias is not None:
        bias_flops = out_channels * active_elements_count

    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)
|
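# Worked example (illustrative numbers, a ResNet-style stem convolution):
# nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) on a [1, 3, 224, 224] input
# yields a [1, 64, 112, 112] output, so the hook counts
#   conv:  7*7*3*64 * 112*112 = 118,013,952
#   bias:        64 * 112*112 =     802,816
#   total:                       118,816,768 FLOPs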
|
def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
    # input-to-hidden and hidden-to-hidden matrix multiplications
    flops += w_ih.shape[0]*w_ih.shape[1]
    flops += w_hh.shape[0]*w_hh.shape[1]
    if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
        # adding the two contributions
        flops += rnn_module.hidden_size
    elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):
        # hadamard product with the reset gate
        flops += rnn_module.hidden_size
        # element-wise additions combining both contributions for the three gates
        flops += rnn_module.hidden_size*3
        # element-wise operations producing the new hidden state
        flops += rnn_module.hidden_size*3
    elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):
        # element-wise additions combining both contributions for the four gates
        flops += rnn_module.hidden_size*4
        # element-wise operations updating the cell state
        flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
        # element-wise operations producing the hidden state
        flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
    return flops
|
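# Worked example (illustrative numbers): for one LSTM step with input_size=64 and
# hidden_size=128, w_ih has shape [512, 64] and w_hh has shape [512, 128], so the
# function counts 512*64 + 512*128 = 98,304 matrix-multiplication FLOPs plus
# 4*128 + 3*128 + 3*128 = 1,280 element-wise FLOPs, i.e. 99,584 per time step and
# per sample (biases are added by the calling hooks).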
|
def rnn_flops_counter_hook(rnn_module, input, output):
    """
    Assumes the batch dimension comes first, contrary to the usual PyTorch
    convention (the resulting count is the same either way).
    If hard sigmoid/tanh activations are used, only counting a comparison per
    element would be accurate.
    """
    flops = 0
    inp = input[0]
    batch_size = inp.shape[0]
    seq_length = inp.shape[1]
    num_layers = rnn_module.num_layers

    for i in range(num_layers):
        w_ih = rnn_module.__getattr__('weight_ih_l' + str(i))
        w_hh = rnn_module.__getattr__('weight_hh_l' + str(i))
        if i == 0:
            input_size = rnn_module.input_size
        else:
            input_size = rnn_module.hidden_size
        flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size)
        if rnn_module.bias:
            b_ih = rnn_module.__getattr__('bias_ih_l' + str(i))
            b_hh = rnn_module.__getattr__('bias_hh_l' + str(i))
            flops += b_ih.shape[0] + b_hh.shape[0]

    flops *= batch_size
    flops *= seq_length
    if rnn_module.bidirectional:
        flops *= 2
    rnn_module.__flops__ += int(flops)
|
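# Note: only the forward-direction weights ('weight_ih_l{i}' / 'weight_hh_l{i}') are
# inspected; for bidirectional modules the reverse direction is accounted for by the
# final multiplication by two rather than by reading the '_reverse' parameters.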
|
def rnn_cell_flops_counter_hook(rnn_cell_module, input, output):
    flops = 0
    inp = input[0]
    batch_size = inp.shape[0]
    w_ih = rnn_cell_module.__getattr__('weight_ih')
    w_hh = rnn_cell_module.__getattr__('weight_hh')
    input_size = inp.shape[1]
    flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size)
    if rnn_cell_module.bias:
        b_ih = rnn_cell_module.__getattr__('bias_ih')
        b_hh = rnn_cell_module.__getattr__('bias_hh')
        flops += b_ih.shape[0] + b_hh.shape[0]

    flops *= batch_size
    rnn_cell_module.__flops__ += int(flops)
|
|
def multihead_attention_counter_hook(multihead_attention_module, input, output):
    flops = 0

    q, k, v = input

    batch_first = multihead_attention_module.batch_first \
        if hasattr(multihead_attention_module, 'batch_first') else False
    if batch_first:
        batch_size = q.shape[0]
        len_idx = 1
    else:
        batch_size = q.shape[1]
        len_idx = 0

    dim_idx = 2

    qdim = q.shape[dim_idx]
    kdim = k.shape[dim_idx]
    vdim = v.shape[dim_idx]

    qlen = q.shape[len_idx]
    klen = k.shape[len_idx]
    vlen = v.shape[len_idx]

    num_heads = multihead_attention_module.num_heads
    assert qdim == multihead_attention_module.embed_dim

    if multihead_attention_module.kdim is None:
        assert kdim == qdim
    if multihead_attention_module.vdim is None:
        assert vdim == qdim

    flops = 0

    # Q scaling
    flops += qlen * qdim

    # initial projections of Q, K and V
    flops += (
        (qlen * qdim * qdim)
        + (klen * kdim * kdim)
        + (vlen * vdim * vdim)
    )

    if multihead_attention_module.in_proj_bias is not None:
        flops += (qlen + klen + vlen) * qdim

    # per-head attention: QK^T matmul, softmax and the matmul with V
    qk_head_dim = qdim // num_heads
    v_head_dim = vdim // num_heads

    head_flops = (
        (qlen * klen * qk_head_dim)
        + (qlen * klen)
        + (qlen * klen * v_head_dim)
    )

    flops += num_heads * head_flops

    # final output projection (its bias is counted unconditionally)
    flops += qlen * vdim * (vdim + 1)

    flops *= batch_size
    multihead_attention_module.__flops__ += int(flops)
|
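# Worked example (illustrative numbers): nn.MultiheadAttention(embed_dim=512,
# num_heads=8) with query/key/value lengths of 100 and batch size 1:
#   Q scaling:           100*512                        =     51,200
#   in-projections:      3 * 100*512*512                = 78,643,200
#   in-projection bias:  (100+100+100)*512              =    153,600
#   attention (8 heads): 8 * (2 * 100*100*64 + 100*100) = 10,320,000
#   out-projection:      100*512*(512+1)                = 26,265,600
#   total                                               = 115,433,600 FLOPs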
|
CUSTOM_MODULES_MAPPING = {} |
|
|
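# Sketch (an assumption about the surrounding package, not a guaranteed API): hooks
# for modules that are not listed in MODULES_MAPPING below can be placed into
# CUSTOM_MODULES_MAPPING, e.g.
#
#   def my_module_flops_counter_hook(module, input, output):
#       module.__flops__ += int(output.numel())
#
#   CUSTOM_MODULES_MAPPING[MyModule] = my_module_flops_counter_hook
#
# assuming the counting engine consults CUSTOM_MODULES_MAPPING before falling back
# to MODULES_MAPPING; MyModule and the hook above are hypothetical names.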
|
MODULES_MAPPING = {
    # convolutions
    nn.Conv1d: conv_flops_counter_hook,
    nn.Conv2d: conv_flops_counter_hook,
    nn.Conv3d: conv_flops_counter_hook,
    # activations
    nn.ReLU: relu_flops_counter_hook,
    nn.PReLU: relu_flops_counter_hook,
    nn.ELU: relu_flops_counter_hook,
    nn.LeakyReLU: relu_flops_counter_hook,
    nn.ReLU6: relu_flops_counter_hook,
    # poolings
    nn.MaxPool1d: pool_flops_counter_hook,
    nn.AvgPool1d: pool_flops_counter_hook,
    nn.AvgPool2d: pool_flops_counter_hook,
    nn.MaxPool2d: pool_flops_counter_hook,
    nn.MaxPool3d: pool_flops_counter_hook,
    nn.AvgPool3d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
    # batch norms
    nn.BatchNorm1d: bn_flops_counter_hook,
    nn.BatchNorm2d: bn_flops_counter_hook,
    nn.BatchNorm3d: bn_flops_counter_hook,
    # instance and group norms
    nn.InstanceNorm1d: bn_flops_counter_hook,
    nn.InstanceNorm2d: bn_flops_counter_hook,
    nn.InstanceNorm3d: bn_flops_counter_hook,
    nn.GroupNorm: bn_flops_counter_hook,
    # fully connected
    nn.Linear: linear_flops_counter_hook,
    # upsampling
    nn.Upsample: upsample_flops_counter_hook,
    # transposed convolutions
    nn.ConvTranspose1d: conv_flops_counter_hook,
    nn.ConvTranspose2d: conv_flops_counter_hook,
    nn.ConvTranspose3d: conv_flops_counter_hook,
    # recurrent modules
    nn.RNN: rnn_flops_counter_hook,
    nn.GRU: rnn_flops_counter_hook,
    nn.LSTM: rnn_flops_counter_hook,
    nn.RNNCell: rnn_cell_flops_counter_hook,
    nn.LSTMCell: rnn_cell_flops_counter_hook,
    nn.GRUCell: rnn_cell_flops_counter_hook,
    # attention
    nn.MultiheadAttention: multihead_attention_counter_hook
}
|
|
# nn.GELU is not available in older PyTorch releases
if hasattr(nn, 'GELU'):
    MODULES_MAPPING[nn.GELU] = relu_flops_counter_hook
|
|
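if __name__ == '__main__':
    # Minimal self-check sketch (not part of the package API): attach the conv hook
    # to a single layer by hand and compare against the closed-form count given in
    # the comments above. The surrounding package normally registers these hooks
    # for a whole model.
    import torch

    conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
    conv.__flops__ = 0
    handle = conv.register_forward_hook(conv_flops_counter_hook)
    with torch.no_grad():
        conv(torch.randn(1, 3, 224, 224))
    handle.remove()
    print(conv.__flops__)  # expected: 7*7*3*64*112*112 + 64*112*112 = 118816768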