Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved. | |
import functools | |
import warnings | |
from inspect import getfullargspec | |
import torch | |
from .utils import cast_tensor_type | |
def auto_fp16(apply_to=None, out_fp32=False):
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If input arguments are fp32 tensors, they will
    be converted to fp16 automatically. Arguments other than fp32 tensors are
    ignored.

    Conversion only happens when the module the decorated method is bound to
    has a truthy ``fp16_enabled`` attribute; otherwise the original method is
    called with its arguments untouched.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments. A single string is treated as one
            argument name.
        out_fp32 (bool): Whether to convert the output back to fp32.

    Returns:
        Callable: A decorator for methods of ``torch.nn.Module``.

    Raises:
        TypeError: When the decorated function is called without an
            ``nn.Module`` instance as its first positional argument.

    Example:
        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp16
        >>>     @auto_fp16()
        >>>     def forward(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp16
        >>>     @auto_fp16(apply_to=('pred', ))
        >>>     def do_something(self, pred, others):
        >>>         pass
    """
    warnings.warn(
        'auto_fp16 in mmpose will be deprecated in the next release. '
        'Please use mmcv.runner.auto_fp16 instead (mmcv>=1.3.1).',
        DeprecationWarning)

    # Normalize once: a bare string would otherwise be matched by substring
    # (`'x' in 'xy'`), and a one-shot iterator would be exhausted after the
    # first call.
    if isinstance(apply_to, str):
        apply_to = (apply_to, )
    elif apply_to is not None:
        apply_to = tuple(apply_to)

    def auto_fp16_wrapper(old_func):
        # The signature does not change between calls, so inspect it once.
        args_info = getfullargspec(old_func)
        # Names of the arguments to be cast.
        names_to_cast = args_info.args if apply_to is None else apply_to

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # The fp16 switch is read from the module instance (args[0]), so
            # only methods of nn.Module are supported.
            if not args or not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            # If the module has not enabled fp16, fall back to the original
            # method.
            if not getattr(args[0], 'fp16_enabled', False):
                return old_func(*args, **kwargs)

            # Positional args are matched to parameter names by position.
            # NOTE: default args are not taken into consideration.
            new_args = [
                cast_tensor_type(arg, torch.float, torch.half)
                if name in names_to_cast else arg
                for name, arg in zip(args_info.args, args)
            ]
            # Extra positional args captured by a `*args` parameter have no
            # named slot; pass them through unchanged instead of dropping them.
            new_args.extend(args[len(args_info.args):])

            # Convert the kwargs that need to be processed.
            new_kwargs = {
                name: cast_tensor_type(value, torch.float, torch.half)
                if name in names_to_cast else value
                for name, value in kwargs.items()
            }

            # Apply the converted arguments to the decorated method.
            output = old_func(*new_args, **new_kwargs)
            # Cast the results back to fp32 if requested.
            if out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper
def force_fp32(apply_to=None, out_fp16=False):
    """Decorator to convert input arguments to fp32 in force.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If there are some inputs that must be processed
    in fp32 mode, then this decorator can handle it. If input arguments are
    fp16 tensors, they will be converted to fp32 automatically. Arguments other
    than fp16 tensors are ignored.

    Conversion only happens when the module the decorated method is bound to
    has a truthy ``fp16_enabled`` attribute; otherwise the original method is
    called with its arguments untouched.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments. A single string is treated as one
            argument name.
        out_fp16 (bool): Whether to convert the output back to fp16.

    Returns:
        Callable: A decorator for methods of ``torch.nn.Module``.

    Raises:
        TypeError: When the decorated function is called without an
            ``nn.Module`` instance as its first positional argument.

    Example:
        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp32
        >>>     @force_fp32()
        >>>     def loss(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp32
        >>>     @force_fp32(apply_to=('pred', ))
        >>>     def post_process(self, pred, others):
        >>>         pass
    """
    warnings.warn(
        'force_fp32 in mmpose will be deprecated in the next release. '
        'Please use mmcv.runner.force_fp32 instead (mmcv>=1.3.1).',
        DeprecationWarning)

    # Normalize once: a bare string would otherwise be matched by substring
    # (`'x' in 'xy'`), and a one-shot iterator would be exhausted after the
    # first call.
    if isinstance(apply_to, str):
        apply_to = (apply_to, )
    elif apply_to is not None:
        apply_to = tuple(apply_to)

    def force_fp32_wrapper(old_func):
        # The signature does not change between calls, so inspect it once.
        args_info = getfullargspec(old_func)
        # Names of the arguments to be cast.
        names_to_cast = args_info.args if apply_to is None else apply_to

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # The fp16 switch is read from the module instance (args[0]), so
            # only methods of nn.Module are supported.
            if not args or not isinstance(args[0], torch.nn.Module):
                raise TypeError('@force_fp32 can only be used to decorate the '
                                'method of nn.Module')
            # If the module has not enabled fp16, fall back to the original
            # method.
            if not getattr(args[0], 'fp16_enabled', False):
                return old_func(*args, **kwargs)

            # Positional args are matched to parameter names by position.
            # NOTE: default args are not taken into consideration.
            new_args = [
                cast_tensor_type(arg, torch.half, torch.float)
                if name in names_to_cast else arg
                for name, arg in zip(args_info.args, args)
            ]
            # Extra positional args captured by a `*args` parameter have no
            # named slot; pass them through unchanged instead of dropping them.
            new_args.extend(args[len(args_info.args):])

            # Convert the kwargs that need to be processed.
            new_kwargs = {
                name: cast_tensor_type(value, torch.half, torch.float)
                if name in names_to_cast else value
                for name, value in kwargs.items()
            }

            # Apply the converted arguments to the decorated method.
            output = old_func(*new_args, **new_kwargs)
            # Cast the results back to fp16 if requested.
            if out_fp16:
                output = cast_tensor_type(output, torch.float, torch.half)
            return output

        return new_func

    return force_fp32_wrapper