|
from . import utils |
|
|
|
import functools |
|
|
|
import torch |
|
|
|
def make_cast_wrapper(orig_fn, cast_fn, handle,
                      try_caching=False):
    """Return a wrapper around ``orig_fn`` that casts fp tensor inputs.

    While ``handle`` is active, every floating-point tensor argument is
    cast via ``cast_fn`` (optionally through ``handle``'s cast cache when
    ``try_caching`` is set), the wrapped function is invoked, and the
    output is cast back to the type of the first fp tensor input.

    Args:
        orig_fn: the original callable being wrapped.
        cast_fn: callable used to cast individual tensors.
        handle: AMP handle providing ``is_active()``, ``has_cache`` and
            ``cache``.  # assumes this interface — TODO confirm with handle impl
        try_caching: consult ``handle``'s cast cache for eligible args.

    Returns:
        A callable with the same signature/metadata as ``orig_fn``.
    """
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        # Fast path: casting disabled, delegate untouched.
        if not handle.is_active():
            return orig_fn(*args, **kwargs)

        # Record the storage type of the first fp tensor input so the
        # output can be cast back to match it.
        input_types = [
            v.data.type() for v in list(args) + list(kwargs.values())
            if utils.is_fp_tensor(v)
        ]

        if not input_types:
            # No floating-point tensor inputs: nothing to cast either way.
            # (Previously this indexed input_types[0] and raised IndexError.)
            return orig_fn(*args, **kwargs)
        input_type = input_types[0]

        if try_caching and handle.has_cache:
            # Replace cacheable arguments with their cached casts in place.
            args = list(args)
            for i in range(len(args)):
                if utils.should_cache(args[i]):
                    args[i] = utils.cached_cast(cast_fn, args[i], handle.cache)
            for k in kwargs:
                if utils.should_cache(kwargs[k]):
                    kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)
        # Cast any remaining fp tensor args (kwargs are updated in place
        # by casted_args — presumably; verify against utils.casted_args).
        new_args = utils.casted_args(cast_fn,
                                     args,
                                     kwargs)
        output = orig_fn(*new_args, **kwargs)

        # Restore the caller-visible dtype of the result.
        return cast_output(output, input_type, verbose=False)

    return wrapper
|
|
|
def cast_output(output, input_type, verbose=False):
    """Cast ``output`` (recursively, for dicts) back to ``input_type``.

    Args:
        output: a tensor, a dict whose values are cast recursively, or any
            other object (returned unchanged).
        input_type: target torch type string (as from ``Tensor.type()``).
        verbose: print a note whenever an actual cast happens.

    Returns:
        ``output`` with every fp tensor converted to ``input_type``; dicts
        are mutated in place and returned.
    """
    if isinstance(output, dict):
        keys = output.keys()
        for k in keys:
            # Bug fix: propagate ``verbose`` into the recursion; it was
            # previously dropped, silencing nested casts.
            output[k] = cast_output(output[k], input_type, verbose=verbose)
        return output

    # Only cast fp tensors whose type actually differs.
    if utils.is_fp_tensor(output) and output.type() != input_type:
        if verbose:
            print('ori output type: {}, input type: {}'.format(output.type(), input_type))
        return output.type(input_type)
    return output
|
|
|
def cached_cast(mod, fn, cast_fn, handle,
                try_caching=False, verbose=False):
    """Monkey-patch ``mod.fn`` with a casting wrapper, if it exists.

    The original function is saved through ``handle`` so the patch can be
    undone later; missing functions are silently skipped.
    """
    if not utils.has_func(mod, fn):
        return

    original = utils.get_func(mod, fn)
    verbose_cast = utils.verbosify(cast_fn, fn, verbose)
    utils.set_func_save(
        handle, mod, fn,
        make_cast_wrapper(original, verbose_cast, handle, try_caching))
|
|
|
|