# image2sketch/ptflops/pytorch_engine.py
'''
Copyright (C) 2021 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
* this file. If not visit https://opensource.org/licenses/MIT
'''
import sys
from functools import partial
import torch
import torch.nn as nn
from .pytorch_ops import CUSTOM_MODULES_MAPPING, MODULES_MAPPING
from .utils import flops_to_string, params_to_string
def get_flops_pytorch(model, input_res,
print_per_layer_stat=True,
input_constructor=None, ost=sys.stdout,
verbose=False, ignore_modules=[],
custom_modules_hooks={},
output_precision=3,
flops_units='GMac',
param_units='M'):
global CUSTOM_MODULES_MAPPING
CUSTOM_MODULES_MAPPING = custom_modules_hooks
flops_model = add_flops_counting_methods(model)
flops_model.eval()
flops_model.start_flops_count(ost=ost, verbose=verbose,
ignore_list=ignore_modules)
if input_constructor:
input = input_constructor(input_res)
_ = flops_model(**input)
else:
try:
batch = torch.ones(()).new_empty((1, *input_res),
dtype=next(flops_model.parameters()).dtype,
device=next(flops_model.parameters()).device)
except StopIteration:
batch = torch.ones(()).new_empty((1, *input_res))
_ = flops_model(batch)
flops_count, params_count = flops_model.compute_average_flops_cost()
if print_per_layer_stat:
print_model_with_flops(
flops_model,
flops_count,
params_count,
ost=ost,
flops_units=flops_units,
param_units=param_units,
precision=output_precision
)
flops_model.stop_flops_count()
CUSTOM_MODULES_MAPPING = {}
return flops_count, params_count
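

# Usage sketch for get_flops_pytorch (illustrative only; torchvision and the
# resnet18 model are assumptions, not part of this module):
#
#   import torchvision.models as models
#   net = models.resnet18()
#   macs, params = get_flops_pytorch(net, (3, 224, 224),
#                                    print_per_layer_stat=False)
#   print(flops_to_string(macs), params_to_string(params))
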
def accumulate_flops(self):
if is_supported_instance(self):
return self.__flops__
else:
        total = 0
        for m in self.children():
            total += m.accumulate_flops()
        return total
def print_model_with_flops(model, total_flops, total_params, flops_units='GMac',
param_units='M', precision=3, ost=sys.stdout):
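    # Clamp to 1 to avoid division by zero in the percentage columns printed
    # below (e.g. a model with no supported ops or no trainable parameters).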
if total_flops < 1:
total_flops = 1
if total_params < 1:
total_params = 1
def accumulate_params(self):
if is_supported_instance(self):
return self.__params__
else:
            total = 0
            for m in self.children():
                total += m.accumulate_params()
            return total
def flops_repr(self):
accumulated_params_num = self.accumulate_params()
accumulated_flops_cost = self.accumulate_flops() / model.__batch_counter__
return ', '.join([params_to_string(accumulated_params_num,
units=param_units, precision=precision),
'{:.3%} Params'.format(accumulated_params_num / total_params),
flops_to_string(accumulated_flops_cost,
units=flops_units, precision=precision),
'{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
self.original_extra_repr()])
def add_extra_repr(m):
m.accumulate_flops = accumulate_flops.__get__(m)
m.accumulate_params = accumulate_params.__get__(m)
flops_extra_repr = flops_repr.__get__(m)
if m.extra_repr != flops_extra_repr:
m.original_extra_repr = m.extra_repr
m.extra_repr = flops_extra_repr
assert m.extra_repr != m.original_extra_repr
def del_extra_repr(m):
if hasattr(m, 'original_extra_repr'):
m.extra_repr = m.original_extra_repr
del m.original_extra_repr
if hasattr(m, 'accumulate_flops'):
del m.accumulate_flops
model.apply(add_extra_repr)
print(repr(model), file=ost)
model.apply(del_extra_repr)
def get_model_parameters_number(model):
params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
return params_num
def add_flops_counting_methods(net_main_module):
    # Bind the counting helpers below to the existing module object, so that
    # each of them receives the network itself as `self`.
net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(
net_main_module)
net_main_module.reset_flops_count()
return net_main_module
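

# The `__get__` calls above bind a plain function as a method of an existing
# object. A minimal standalone sketch of the same mechanism (the names here
# are illustrative, not part of ptflops):
#
#   def describe(self):
#       return 'module with {} direct children'.format(len(list(self.children())))
#
#   layer = nn.Linear(4, 2)
#   layer.describe = describe.__get__(layer)   # bound to this instance only
#   print(layer.describe())
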
def compute_average_flops_cost(self):
"""
A method that will be available after add_flops_counting_methods() is called
on a desired net object.
    Returns the mean flops (multiply-accumulate) consumption per image
    accumulated so far, together with the number of trainable parameters.
"""
for m in self.modules():
m.accumulate_flops = accumulate_flops.__get__(m)
flops_sum = self.accumulate_flops()
for m in self.modules():
if hasattr(m, 'accumulate_flops'):
del m.accumulate_flops
params_sum = get_model_parameters_number(self)
return flops_sum / self.__batch_counter__, params_sum
def start_flops_count(self, **kwargs):
"""
A method that will be available after add_flops_counting_methods() is called
on a desired net object.
Activates the computation of mean flops consumption per image.
Call it before you run the network.
"""
add_batch_counter_hook_function(self)
seen_types = set()
def add_flops_counter_hook_function(module, ost, verbose, ignore_list):
if type(module) in ignore_list:
seen_types.add(type(module))
if is_supported_instance(module):
module.__params__ = 0
elif is_supported_instance(module):
if hasattr(module, '__flops_handle__'):
return
if type(module) in CUSTOM_MODULES_MAPPING:
handle = module.register_forward_hook(
CUSTOM_MODULES_MAPPING[type(module)])
else:
handle = module.register_forward_hook(MODULES_MAPPING[type(module)])
module.__flops_handle__ = handle
seen_types.add(type(module))
else:
            if verbose and type(module) not in (nn.Sequential, nn.ModuleList) \
                    and type(module) not in seen_types:
print('Warning: module ' + type(module).__name__ +
' is treated as a zero-op.', file=ost)
seen_types.add(type(module))
self.apply(partial(add_flops_counter_hook_function, **kwargs))
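

# Entries in CUSTOM_MODULES_MAPPING / custom_modules_hooks are forward-hook
# style callables that add a module's cost to its `__flops__` attribute.
# A hedged sketch for a hypothetical custom layer (`MyLayer`, `net` and the
# cost formula are illustrative assumptions):
#
#   def my_layer_counter_hook(module, input, output):
#       module.__flops__ += int(output.numel())   # e.g. one MAC per output element
#
#   macs, params = get_flops_pytorch(net, (3, 224, 224),
#                                    custom_modules_hooks={MyLayer: my_layer_counter_hook})
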
def stop_flops_count(self):
"""
A method that will be available after add_flops_counting_methods() is called
on a desired net object.
Stops computing the mean flops consumption per image.
Call whenever you want to pause the computation.
"""
remove_batch_counter_hook_function(self)
self.apply(remove_flops_counter_hook_function)
self.apply(remove_flops_counter_variables)
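

# A minimal sketch of the manual counting flow exposed by the methods above
# (the tiny model and input are illustrative assumptions):
#
#   net = add_flops_counting_methods(nn.Sequential(nn.Linear(8, 4), nn.ReLU()))
#   net.eval()
#   net.start_flops_count(ost=sys.stdout, verbose=False, ignore_list=[])
#   net(torch.randn(2, 8))
#   macs, params = net.compute_average_flops_cost()
#   net.stop_flops_count()
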
def reset_flops_count(self):
"""
A method that will be available after add_flops_counting_methods() is called
on a desired net object.
Resets statistics computed so far.
"""
add_batch_counter_variables_or_reset(self)
self.apply(add_flops_counter_variable_or_reset)
# ---- Internal functions
def batch_counter_hook(module, input, output):
batch_size = 1
if len(input) > 0:
# Can have multiple inputs, getting the first one
input = input[0]
batch_size = len(input)
    else:
        print('Warning! No positional inputs found for a module,'
              ' assuming batch size is 1.')
module.__batch_counter__ += batch_size
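

# batch_counter_hook relies on torch's forward-hook contract: a hook registered
# with register_forward_hook(fn) is called as fn(module, input, output) after
# every forward pass, where `input` is the tuple of positional inputs.
# A standalone sketch (not part of ptflops):
#
#   seen = []
#   layer = nn.Linear(8, 4)
#   handle = layer.register_forward_hook(lambda m, inp, out: seen.append(inp[0].shape[0]))
#   layer(torch.randn(3, 8))
#   handle.remove()                     # seen == [3]
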
def add_batch_counter_variables_or_reset(module):
module.__batch_counter__ = 0
def add_batch_counter_hook_function(module):
if hasattr(module, '__batch_counter_handle__'):
return
handle = module.register_forward_hook(batch_counter_hook)
module.__batch_counter_handle__ = handle
def remove_batch_counter_hook_function(module):
if hasattr(module, '__batch_counter_handle__'):
module.__batch_counter_handle__.remove()
del module.__batch_counter_handle__
def add_flops_counter_variable_or_reset(module):
if is_supported_instance(module):
if hasattr(module, '__flops__') or hasattr(module, '__params__'):
            print('Warning: variables __flops__ or __params__ are already '
                  'defined for the module ' + type(module).__name__ +
                  '. ptflops can affect your code!')
            if hasattr(module, '__flops__'):
                module.__ptflops_backup_flops__ = module.__flops__
            if hasattr(module, '__params__'):
                module.__ptflops_backup_params__ = module.__params__
module.__flops__ = 0
module.__params__ = get_model_parameters_number(module)
def is_supported_instance(module):
if type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING:
return True
return False
def remove_flops_counter_hook_function(module):
if is_supported_instance(module):
if hasattr(module, '__flops_handle__'):
module.__flops_handle__.remove()
del module.__flops_handle__
def remove_flops_counter_variables(module):
if is_supported_instance(module):
if hasattr(module, '__flops__'):
del module.__flops__
if hasattr(module, '__ptflops_backup_flops__'):
module.__flops__ = module.__ptflops_backup_flops__
if hasattr(module, '__params__'):
del module.__params__
if hasattr(module, '__ptflops_backup_params__'):
module.__params__ = module.__ptflops_backup_params__
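

if __name__ == '__main__':
    # Minimal smoke-test sketch, assuming a tiny throwaway model; the layer
    # sizes and the (3, 8, 8) input resolution below are illustrative only.
    demo = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 6 * 6, 10),
    )
    macs, params = get_flops_pytorch(demo, (3, 8, 8), print_per_layer_stat=True)
    print('Computational complexity:', flops_to_string(macs))
    print('Number of parameters:', params_to_string(params))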