|
''' |
|
Copyright (C) 2021 Sovrasov V. - All Rights Reserved |
|
* You may use, distribute and modify this code under the |
|
* terms of the MIT license. |
|
* You should have received a copy of the MIT license with |
|
* this file. If not visit https://opensource.org/licenses/MIT |
|
''' |
|
|
|
import sys |
|
from functools import partial |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from .pytorch_ops import CUSTOM_MODULES_MAPPING, MODULES_MAPPING |
|
from .utils import flops_to_string, params_to_string |
|
|
|
|
|
def get_flops_pytorch(model, input_res,
                      print_per_layer_stat=True,
                      input_constructor=None, ost=sys.stdout,
                      verbose=False, ignore_modules=[],
                      custom_modules_hooks={},
                      output_precision=3,
                      flops_units='GMac',
                      param_units='M'):
    """Run one forward pass through ``model`` and return its measured cost.

    The model is instrumented with counting hooks, switched to eval mode
    (it is not switched back afterwards), executed once, and then cleaned
    up again. Cleanup happens even when the forward pass or the report
    raises, so a failed measurement does not leave hooks, counter
    attributes or a stale custom-hook mapping behind.

    Args:
        model: module to measure; it is modified in place while counting.
        input_res: input shape without the batch dimension. It is either
            passed to ``input_constructor`` or used to build a dummy batch
            of shape ``(1, *input_res)``.
        print_per_layer_stat: when true, print the annotated model to
            ``ost``.
        input_constructor: optional callable taking ``input_res`` and
            returning a mapping that is unpacked as keyword arguments of
            the model call.
        ost: stream that receives warnings and the per-layer report.
        verbose: forwarded to ``start_flops_count``.
        ignore_modules: module types forwarded to ``start_flops_count`` as
            ``ignore_list``. The mutable default is never modified here.
        custom_modules_hooks: mapping from module type to forward hook; it
            becomes the module-level ``CUSTOM_MODULES_MAPPING`` for the
            duration of the call. The mutable default is never modified
            here.
        output_precision: precision used in the printed report.
        flops_units: unit label used in the printed report.
        param_units: unit label used in the printed report.

    Returns:
        Tuple ``(flops_count, params_count)`` as produced by
        ``compute_average_flops_cost``.
    """
    global CUSTOM_MODULES_MAPPING
    CUSTOM_MODULES_MAPPING = custom_modules_hooks
    try:
        flops_model = add_flops_counting_methods(model)
        try:
            flops_model.eval()
            flops_model.start_flops_count(ost=ost, verbose=verbose,
                                          ignore_list=ignore_modules)
            if input_constructor:
                model_kwargs = input_constructor(input_res)
                _ = flops_model(**model_kwargs)
            else:
                # Match the dummy batch to the first parameter's dtype and
                # device; a parameter-free model gets a default tensor.
                try:
                    first_param = next(flops_model.parameters())
                except StopIteration:
                    batch = torch.ones(()).new_empty((1, *input_res))
                else:
                    batch = torch.ones(()).new_empty(
                        (1, *input_res),
                        dtype=first_param.dtype,
                        device=first_param.device)

                _ = flops_model(batch)

            flops_count, params_count = \
                flops_model.compute_average_flops_cost()
            if print_per_layer_stat:
                print_model_with_flops(
                    flops_model,
                    flops_count,
                    params_count,
                    ost=ost,
                    flops_units=flops_units,
                    param_units=param_units,
                    precision=output_precision
                )
        finally:
            # Must run before the mapping is reset below: the removal
            # helpers consult CUSTOM_MODULES_MAPPING to recognise modules.
            flops_model.stop_flops_count()
    finally:
        CUSTOM_MODULES_MAPPING = {}

    return flops_count, params_count
|
|
|
|
|
def accumulate_flops(self):
    """Return the FLOPs counted for this module, including its subtree.

    Intended to be bound to a module instance. A module recognised by
    ``is_supported_instance`` reports its own ``__flops__`` counter; any
    other module reports the sum over its direct children, each of which
    must have this method bound as well.
    """
    if not is_supported_instance(self):
        return sum(child.accumulate_flops() for child in self.children())
    return self.__flops__
|
|
|
|
|
def print_model_with_flops(model, total_flops, total_params, flops_units='GMac',
                           param_units='M', precision=3, ost=sys.stdout):
    """Print ``repr(model)`` with per-module parameter and MAC statistics.

    Every module's ``extra_repr`` is temporarily replaced by one that
    prepends the accumulated parameter count, its share of
    ``total_params``, the accumulated per-batch-item cost and its share of
    ``total_flops``. The original ``extra_repr`` and all temporary helper
    attributes are removed again afterwards, even if printing fails.

    Args:
        model: instrumented root module; its ``__batch_counter__`` is used
            to average the accumulated cost.
        total_flops: denominator for the MACs percentage (values below 1
            are replaced by 1).
        total_params: denominator for the Params percentage (values below
            1 are replaced by 1).
        flops_units: units passed to ``flops_to_string``.
        param_units: units passed to ``params_to_string``.
        precision: precision passed to both formatting helpers.
        ost: stream the report is printed to.
    """
    # Clamp the denominators so the percentage columns never divide by 0.
    if total_flops < 1:
        total_flops = 1
    if total_params < 1:
        total_params = 1

    def accumulate_params(self):
        # Same traversal as accumulate_flops, but over __params__.
        if is_supported_instance(self):
            return self.__params__
        return sum(m.accumulate_params() for m in self.children())

    def flops_repr(self):
        accumulated_params_num = self.accumulate_params()
        accumulated_flops_cost = self.accumulate_flops() / model.__batch_counter__
        return ', '.join([params_to_string(accumulated_params_num,
                                           units=param_units, precision=precision),
                          '{:.3%} Params'.format(accumulated_params_num / total_params),
                          flops_to_string(accumulated_flops_cost,
                                          units=flops_units, precision=precision),
                          '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        m.accumulate_params = accumulate_params.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        # A module may be visited more than once; only wrap it the first
        # time so original_extra_repr keeps pointing at the real one.
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        # Remove both helpers bound by add_extra_repr; leaving
        # accumulate_params behind would keep a stale bound method on
        # every module.
        for helper_name in ('accumulate_flops', 'accumulate_params'):
            if hasattr(m, helper_name):
                delattr(m, helper_name)

    try:
        model.apply(add_extra_repr)
        print(repr(model), file=ost)
    finally:
        # del_extra_repr only touches attributes that exist, so it is safe
        # to run after a partial setup as well.
        model.apply(del_extra_repr)
|
|
|
|
|
def get_model_parameters_number(model):
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable_sizes = [
        param.numel()
        for param in model.parameters()
        if param.requires_grad
    ]
    return sum(trainable_sizes)
|
|
|
|
|
def add_flops_counting_methods(net_main_module):
    """Attach the FLOPs-counting methods to ``net_main_module``.

    The start/stop/reset/compute functions are bound to the module
    instance so that each receives it as ``self``. The counters are then
    reset once, and the same module object is returned.
    """
    counting_functions = (
        start_flops_count,
        stop_flops_count,
        reset_flops_count,
        compute_average_flops_cost,
    )
    for function in counting_functions:
        # function.__get__(obj) yields a method bound to this instance.
        bound_method = function.__get__(net_main_module)
        setattr(net_main_module, function.__name__, bound_method)

    net_main_module.reset_flops_count()

    return net_main_module
|
|
|
|
|
def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.

    Returns a tuple ``(mean_flops, params)``: the accumulated FLOPs
    divided by ``self.__batch_counter__``, and the number of trainable
    parameters of the whole net. The batch counter starts at zero, so a
    forward pass has to be run before this is called.
    """
    submodules = list(self.modules())

    # Temporarily give every submodule the recursive accumulator.
    for submodule in submodules:
        submodule.accumulate_flops = accumulate_flops.__get__(submodule)

    total_flops = self.accumulate_flops()

    # Drop the temporary accumulator again.
    for submodule in submodules:
        if hasattr(submodule, 'accumulate_flops'):
            del submodule.accumulate_flops

    total_params = get_model_parameters_number(self)
    mean_flops = total_flops / self.__batch_counter__
    return mean_flops, total_params
|
|
|
|
|
def start_flops_count(self, **kwargs):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.

    Activates the computation of mean flops consumption per image by
    registering the batch-counter hook on this module and a FLOPs hook on
    every supported submodule. Call it before you run the network.

    The keyword arguments are forwarded to the per-module helper and must
    provide ``ost`` (stream for warnings), ``verbose`` (whether to warn
    about unsupported module types) and ``ignore_list`` (module types that
    get no hook).
    """
    add_batch_counter_hook_function(self)

    handled_types = set()
    container_types = (nn.Sequential, nn.ModuleList)

    def add_flops_counter_hook_function(module, ost, verbose, ignore_list):
        module_type = type(module)

        # Ignored types get no hook, and their parameters are not counted
        # in the per-layer statistics.
        if module_type in ignore_list:
            handled_types.add(module_type)
            if is_supported_instance(module):
                module.__params__ = 0
            return

        # Unsupported types contribute nothing; warn once per type.
        if not is_supported_instance(module):
            if verbose and module_type not in container_types and \
               module_type not in handled_types:
                print('Warning: module ' + type(module).__name__ +
                      ' is treated as a zero-op.', file=ost)
            handled_types.add(module_type)
            return

        # Supported type: register its hook unless one is already active.
        if hasattr(module, '__flops_handle__'):
            return
        if module_type in CUSTOM_MODULES_MAPPING:
            flops_hook = CUSTOM_MODULES_MAPPING[module_type]
        else:
            flops_hook = MODULES_MAPPING[module_type]
        module.__flops_handle__ = module.register_forward_hook(flops_hook)
        handled_types.add(module_type)

    self.apply(partial(add_flops_counter_hook_function, **kwargs))
|
|
|
|
|
def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.

    Stops computing the mean flops consumption per image: removes the
    batch-counter hook from this module, then removes the FLOPs hooks and
    the counter variables from every submodule.
    """
    remove_batch_counter_hook_function(self)

    # Hooks are detached first, then the counters they were updating.
    per_module_cleanups = (
        remove_flops_counter_hook_function,
        remove_flops_counter_variables,
    )
    for cleanup in per_module_cleanups:
        self.apply(cleanup)
|
|
|
|
|
def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.

    Resets statistics computed so far: the batch counter on this module
    and the per-module counter variables across the module tree.
    """
    # The batch counter lives on the root module only.
    add_batch_counter_variables_or_reset(self)

    # Per-module counters are (re)initialised for the whole tree.
    per_module_reset = add_flops_counter_variable_or_reset
    self.apply(per_module_reset)
|
|
|
|
|
|
|
def batch_counter_hook(module, input, output):
    """Forward hook that adds the current batch size to the module counter.

    The batch size is the length of the first positional input. When the
    module was called without positional inputs, a warning is printed and
    a batch size of 1 is assumed. ``output`` is not used.
    """
    if len(input) == 0:
        print('Warning! No positional inputs found for a module,'
              ' assuming batch size is 1.')
        samples_in_batch = 1
    else:
        # There can be several inputs; only the first one is inspected.
        first_input = input[0]
        samples_in_batch = len(first_input)
    module.__batch_counter__ += samples_in_batch
|
|
|
|
|
def add_batch_counter_variables_or_reset(module):
    """Create the ``__batch_counter__`` attribute on ``module`` or zero it."""
    setattr(module, '__batch_counter__', 0)
|
|
|
|
|
def add_batch_counter_hook_function(module):
    """Register ``batch_counter_hook`` on ``module`` unless already done.

    The hook handle is stored as ``module.__batch_counter_handle__``; the
    presence of that attribute marks the hook as already registered.
    """
    if not hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__ = \
            module.register_forward_hook(batch_counter_hook)
|
|
|
|
|
def remove_batch_counter_hook_function(module):
    """Detach the batch-counter hook from ``module`` if one is registered.

    Calls ``remove()`` on the stored handle and deletes the
    ``__batch_counter_handle__`` attribute. Does nothing when the
    attribute is absent.
    """
    if not hasattr(module, '__batch_counter_handle__'):
        return

    hook_handle = module.__batch_counter_handle__
    hook_handle.remove()
    del module.__batch_counter_handle__
|
|
|
|
|
def add_flops_counter_variable_or_reset(module):
    """Initialise ``__flops__`` and ``__params__`` on a supported module.

    Modules not recognised by ``is_supported_instance`` are left
    untouched. If either attribute already exists, a warning is printed
    and each existing value is saved under ``__ptflops_backup_flops__`` /
    ``__ptflops_backup_params__`` before being overwritten. Only the
    attributes that are actually present are backed up, so a module that
    defines just one of them no longer raises ``AttributeError``.

    Afterwards ``__flops__`` is 0 and ``__params__`` holds the module's
    trainable parameter count.
    """
    if not is_supported_instance(module):
        return

    has_flops = hasattr(module, '__flops__')
    has_params = hasattr(module, '__params__')
    if has_flops or has_params:
        print('Warning: variables __flops__ or __params__ are already '
              'defined for the module ' + type(module).__name__ +
              ' ptflops can affect your code!')
        # Back up each attribute independently; either may be missing.
        if has_flops:
            module.__ptflops_backup_flops__ = module.__flops__
        if has_params:
            module.__ptflops_backup_params__ = module.__params__

    module.__flops__ = 0
    module.__params__ = get_model_parameters_number(module)
|
|
|
|
|
def is_supported_instance(module):
    """Return True when the exact type of ``module`` has a FLOPs hook.

    The type is looked up in ``MODULES_MAPPING`` and in the currently
    active ``CUSTOM_MODULES_MAPPING``; subclasses of a mapped type do not
    match unless they are mapped themselves.
    """
    module_type = type(module)
    return (module_type in MODULES_MAPPING
            or module_type in CUSTOM_MODULES_MAPPING)
|
|
|
|
|
def remove_flops_counter_hook_function(module):
    """Detach the FLOPs hook from a supported module, if one is registered.

    Calls ``remove()`` on the handle stored in ``__flops_handle__`` and
    deletes that attribute. Unsupported modules and modules without a
    handle are left untouched.
    """
    if is_supported_instance(module) and hasattr(module, '__flops_handle__'):
        module.__flops_handle__.remove()
        del module.__flops_handle__
|
|
|
|
|
def remove_flops_counter_variables(module):
    """Remove ``__flops__`` / ``__params__`` from a supported module.

    For each counter that exists, the attribute is deleted and, when a
    matching ``__ptflops_backup_*__`` value is present, the counter is set
    back to that backed-up value. The backup attributes themselves are
    left in place. Unsupported modules are not touched.
    """
    if not is_supported_instance(module):
        return

    counters_and_backups = (
        ('__flops__', '__ptflops_backup_flops__'),
        ('__params__', '__ptflops_backup_params__'),
    )
    for counter_name, backup_name in counters_and_backups:
        if not hasattr(module, counter_name):
            continue
        delattr(module, counter_name)
        if hasattr(module, backup_name):
            setattr(module, counter_name, getattr(module, backup_name))
|
|