'''
Copyright (C) 2021 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
* this file. If not visit https://opensource.org/licenses/MIT
'''
import numpy as np
import torch.nn as nn
def empty_flops_counter_hook(module, input, output):
    module.__flops__ += 0


def upsample_flops_counter_hook(module, input, output):
    output_size = output[0]
    batch_size = output_size.shape[0]
    output_elements_count = batch_size
    for val in output_size.shape[1:]:
        output_elements_count *= val
    module.__flops__ += int(output_elements_count)


def relu_flops_counter_hook(module, input, output):
    active_elements_count = output.numel()
    module.__flops__ += int(active_elements_count)


def linear_flops_counter_hook(module, input, output):
    input = input[0]
    # pytorch already validates the dimensions, so no extra checks are needed here
    output_last_dim = output.shape[-1]
    bias_flops = output_last_dim if module.bias is not None else 0
    module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops)
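# Worked example (illustration only, not part of the upstream file): for
# nn.Linear(128, 64) applied to an input of shape (8, 128), the hook above
# counts 8 * 128 * 64 + 64 = 65,600 multiply-accumulate operations; note that
# the bias term is added once per output feature, not per batch element.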
def pool_flops_counter_hook(module, input, output):
    input = input[0]
    module.__flops__ += int(np.prod(input.shape))


def bn_flops_counter_hook(module, input, output):
    input = input[0]
    batch_flops = np.prod(input.shape)
    if module.affine:
        batch_flops *= 2
    module.__flops__ += int(batch_flops)


def conv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]
    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])
    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups
    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(np.prod(kernel_dims)) * \
        in_channels * filters_per_channel
    active_elements_count = batch_size * int(np.prod(output_dims))
    overall_conv_flops = conv_per_position_flops * active_elements_count
    bias_flops = 0
    if conv_module.bias is not None:
        bias_flops = out_channels * active_elements_count
    overall_flops = overall_conv_flops + bias_flops
    conv_module.__flops__ += int(overall_flops)
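# Worked example (illustration only, not part of the upstream file): for
# nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=True) on an input of shape
# (1, 16, 32, 32), the output has shape (1, 32, 32, 32), so
#   conv_per_position_flops = 3 * 3 * 16 * 32 = 4608
#   active_elements_count   = 1 * 32 * 32   = 1024
#   overall_flops           = 4608 * 1024 + 32 * 1024 = 4,751,360
# i.e. the hook counts multiply-accumulates (MACs), not 2x FLOPs.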
def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
    # matrix mult between the input and the input-hidden weights
    flops += w_ih.shape[0] * w_ih.shape[1]
    # matrix mult between the hidden state and the hidden-hidden weights
    flops += w_hh.shape[0] * w_hh.shape[1]
    if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
        # add both contributions
        flops += rnn_module.hidden_size
    elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):
        # hadamard product with the reset gate r
        flops += rnn_module.hidden_size
        # additions combining both contributions
        flops += rnn_module.hidden_size * 3
        # last two hadamard products and the final add
        flops += rnn_module.hidden_size * 3
    elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):
        # additions combining both contributions for the gates
        flops += rnn_module.hidden_size * 4
        # two hadamard products and an add for the cell state C
        flops += rnn_module.hidden_size * 3
        # final hadamard product
        flops += rnn_module.hidden_size * 3
    return flops
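# For reference (illustration only): for an LSTM layer with input size I and
# hidden size H, w_ih has shape (4H, I) and w_hh has shape (4H, H), so the
# function above returns 4H*I + 4H*H matrix MACs plus 10H element-wise
# operations per time step and per sample.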
def rnn_flops_counter_hook(rnn_module, input, output):
    """
    Assumes the batch dimension comes first, contrary to PyTorch's default
    (sequence-first) convention; since the count is simply multiplied by
    batch_size and seq_length, the ordering does not change the result.
    Sigmoid and tanh are counted as if they were hard (single-comparison)
    variants, so the estimate is only accurate under that assumption.
    """
    flops = 0
    # input is a tuple containing a sequence to process and (optionally) the hidden state
    inp = input[0]
    batch_size = inp.shape[0]
    seq_length = inp.shape[1]
    num_layers = rnn_module.num_layers
    for i in range(num_layers):
        w_ih = getattr(rnn_module, 'weight_ih_l' + str(i))
        w_hh = getattr(rnn_module, 'weight_hh_l' + str(i))
        if i == 0:
            input_size = rnn_module.input_size
        else:
            input_size = rnn_module.hidden_size
        flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size)
        if rnn_module.bias:
            b_ih = getattr(rnn_module, 'bias_ih_l' + str(i))
            b_hh = getattr(rnn_module, 'bias_hh_l' + str(i))
            flops += b_ih.shape[0] + b_hh.shape[0]
    flops *= batch_size
    flops *= seq_length
    if rnn_module.bidirectional:
        flops *= 2
    rnn_module.__flops__ += int(flops)
def rnn_cell_flops_counter_hook(rnn_cell_module, input, output):
    flops = 0
    inp = input[0]
    batch_size = inp.shape[0]
    w_ih = getattr(rnn_cell_module, 'weight_ih')
    w_hh = getattr(rnn_cell_module, 'weight_hh')
    input_size = inp.shape[1]
    flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size)
    if rnn_cell_module.bias:
        b_ih = getattr(rnn_cell_module, 'bias_ih')
        b_hh = getattr(rnn_cell_module, 'bias_hh')
        flops += b_ih.shape[0] + b_hh.shape[0]
    flops *= batch_size
    rnn_cell_module.__flops__ += int(flops)
def multihead_attention_counter_hook(multihead_attention_module, input, output):
    q, k, v = input
    batch_first = multihead_attention_module.batch_first \
        if hasattr(multihead_attention_module, 'batch_first') else False
    if batch_first:
        batch_size = q.shape[0]
        len_idx = 1
    else:
        batch_size = q.shape[1]
        len_idx = 0
    dim_idx = 2

    qdim = q.shape[dim_idx]
    kdim = k.shape[dim_idx]
    vdim = v.shape[dim_idx]

    qlen = q.shape[len_idx]
    klen = k.shape[len_idx]
    vlen = v.shape[len_idx]

    num_heads = multihead_attention_module.num_heads
    assert qdim == multihead_attention_module.embed_dim

    if multihead_attention_module.kdim is None:
        assert kdim == qdim
    if multihead_attention_module.vdim is None:
        assert vdim == qdim

    flops = 0

    # Q scaling
    flops += qlen * qdim

    # Initial projections
    flops += (
        (qlen * qdim * qdim)  # QW
        + (klen * kdim * kdim)  # KW
        + (vlen * vdim * vdim)  # VW
    )

    if multihead_attention_module.in_proj_bias is not None:
        flops += (qlen + klen + vlen) * qdim

    # attention heads: scale, matmul, softmax, matmul
    qk_head_dim = qdim // num_heads
    v_head_dim = vdim // num_heads

    head_flops = (
        (qlen * klen * qk_head_dim)  # QK^T
        + (qlen * klen)  # softmax
        + (qlen * klen * v_head_dim)  # AV
    )
    flops += num_heads * head_flops

    # final projection, bias is always enabled
    flops += qlen * vdim * (vdim + 1)

    flops *= batch_size
    multihead_attention_module.__flops__ += int(flops)
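# Summary (illustration only): per batch element the hook above counts
#   qlen*qdim                                              (query scaling)
# + qlen*qdim^2 + klen*kdim^2 + vlen*vdim^2                (input projections)
# + (qlen + klen + vlen)*qdim                              (projection bias, if present)
# + num_heads * qlen*klen*(qk_head_dim + 1 + v_head_dim)   (attention)
# + qlen*vdim*(vdim + 1)                                   (output projection)
# multiply-accumulate operations, then multiplies the total by batch_size.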
CUSTOM_MODULES_MAPPING = {}
MODULES_MAPPING = {
    # convolutions
    nn.Conv1d: conv_flops_counter_hook,
    nn.Conv2d: conv_flops_counter_hook,
    nn.Conv3d: conv_flops_counter_hook,
    # activations
    nn.ReLU: relu_flops_counter_hook,
    nn.PReLU: relu_flops_counter_hook,
    nn.ELU: relu_flops_counter_hook,
    nn.LeakyReLU: relu_flops_counter_hook,
    nn.ReLU6: relu_flops_counter_hook,
    # poolings
    nn.MaxPool1d: pool_flops_counter_hook,
    nn.AvgPool1d: pool_flops_counter_hook,
    nn.AvgPool2d: pool_flops_counter_hook,
    nn.MaxPool2d: pool_flops_counter_hook,
    nn.MaxPool3d: pool_flops_counter_hook,
    nn.AvgPool3d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
    # BNs
    nn.BatchNorm1d: bn_flops_counter_hook,
    nn.BatchNorm2d: bn_flops_counter_hook,
    nn.BatchNorm3d: bn_flops_counter_hook,
    nn.InstanceNorm1d: bn_flops_counter_hook,
    nn.InstanceNorm2d: bn_flops_counter_hook,
    nn.InstanceNorm3d: bn_flops_counter_hook,
    nn.GroupNorm: bn_flops_counter_hook,
    # FC
    nn.Linear: linear_flops_counter_hook,
    # Upscale
    nn.Upsample: upsample_flops_counter_hook,
    # Deconvolution
    nn.ConvTranspose1d: conv_flops_counter_hook,
    nn.ConvTranspose2d: conv_flops_counter_hook,
    nn.ConvTranspose3d: conv_flops_counter_hook,
    # RNN
    nn.RNN: rnn_flops_counter_hook,
    nn.GRU: rnn_flops_counter_hook,
    nn.LSTM: rnn_flops_counter_hook,
    nn.RNNCell: rnn_cell_flops_counter_hook,
    nn.LSTMCell: rnn_cell_flops_counter_hook,
    nn.GRUCell: rnn_cell_flops_counter_hook,
    nn.MultiheadAttention: multihead_attention_counter_hook,
}

if hasattr(nn, 'GELU'):
    MODULES_MAPPING[nn.GELU] = relu_flops_counter_hook
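# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream ptflops API): these hooks are
# normally registered for you by ptflops' get_model_complexity_info, but they
# can also be attached by hand with register_forward_hook, provided the target
# module's __flops__ attribute is initialised to zero beforehand.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    conv = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=True)
    conv.__flops__ = 0  # the counter the hook accumulates into
    handle = conv.register_forward_hook(MODULES_MAPPING[type(conv)])

    with torch.no_grad():
        conv(torch.randn(1, 16, 32, 32))

    handle.remove()
    # Expected: 4,751,360 MACs, matching the worked example after
    # conv_flops_counter_hook above.
    print('Conv2d MACs:', conv.__flops__)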