tuandunghcmut's picture
Upload folder using huggingface_hub
345ee20 verified
from . import utils
import functools
import torch
def make_cast_wrapper(orig_fn, cast_fn, handle,
                      try_caching=False):
    """Wrap ``orig_fn`` so its floating-point arguments go through ``cast_fn``.

    The returned wrapper behaves as follows:

    * If ``handle.is_active()`` is false, ``orig_fn`` is called with the
      arguments untouched.
    * Otherwise the type string of every positional/keyword argument that
      ``utils.is_fp_tensor`` accepts is recorded, the arguments are cast via
      ``utils.casted_args`` (optionally going through ``handle.cache`` first
      when ``try_caching`` is set and ``handle.has_cache`` is true), and
      ``orig_fn`` is invoked with the result.
    * The output is converted back to the type of the *first* floating-point
      input with ``cast_output``. When the call has no floating-point input
      there is nothing to convert back to, so the output is returned as-is.

    Args:
        orig_fn: The callable being wrapped.
        cast_fn: Callable applied to floating-point arguments.
        handle: Object exposing ``is_active()``, ``has_cache`` and ``cache``.
        try_caching: Whether to route cacheable arguments through
            ``utils.cached_cast`` before the regular cast.

    Returns:
        The wrapping function, carrying ``orig_fn``'s metadata.
    """
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        if not handle.is_active():
            return orig_fn(*args, **kwargs)

        # Remember the incoming types *before* any casting happens.
        # NOTE(review): assumes every value accepted by utils.is_fp_tensor
        # exposes ``.data.type()`` -- confirm for nested containers.
        input_types = [
            v.data.type() for v in list(args) + list(kwargs.values())
            if utils.is_fp_tensor(v)
        ]

        if try_caching and handle.has_cache:
            args = list(args)
            for i, arg in enumerate(args):
                if utils.should_cache(arg):
                    args[i] = utils.cached_cast(cast_fn, arg, handle.cache)
            # Only existing keys are rebound, so iterating kwargs is safe.
            for k in kwargs:
                if utils.should_cache(kwargs[k]):
                    kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)

        new_args = utils.casted_args(cast_fn, args, kwargs)
        output = orig_fn(*new_args, **kwargs)

        if not input_types:
            # No floating-point input: there is no reference type to restore,
            # and indexing input_types[0] would raise IndexError.
            return output
        return cast_output(output, input_types[0], verbose=False)
    return wrapper
def cast_output(output, input_type, verbose=False):
    """Convert floating-point results to ``input_type``.

    Dicts are processed recursively: each value is replaced in place by its
    converted counterpart and the same dict object is returned. A value that
    ``utils.is_fp_tensor`` accepts and whose ``type()`` differs from
    ``input_type`` is converted with ``output.type(input_type)``. Everything
    else is returned unchanged.

    Args:
        output: The value (or dict of values) produced by a wrapped call.
        input_type: Type string to convert floating-point values to.
        verbose: When true, print a message for every value that is
            converted, including values nested inside dicts.

    Returns:
        ``output`` itself, or a converted copy of it.
    """
    if isinstance(output, dict):
        # Rebinding existing keys does not resize the dict, so it is safe to
        # assign while iterating.
        for key in output:
            output[key] = cast_output(output[key], input_type, verbose=verbose)
        return output
    if utils.is_fp_tensor(output) and output.type() != input_type:
        if verbose:
            print('ori output type: {}, input type: {}'.format(output.type(), input_type))
        return output.type(input_type)
    return output
def cached_cast(mod, fn, cast_fn, handle,
                try_caching=False, verbose=False):
    """Replace ``mod.fn`` with a casting wrapper, if ``mod`` has such a function.

    Does nothing (returns ``None``) when ``utils.has_func(mod, fn)`` is false.
    Otherwise the current function is fetched, ``cast_fn`` is passed through
    ``utils.verbosify``, a wrapper is built with ``make_cast_wrapper`` and
    installed through ``utils.set_func_save``.
    """
    if utils.has_func(mod, fn):
        target = utils.get_func(mod, fn)
        reporting_cast = utils.verbosify(cast_fn, fn, verbose)
        utils.set_func_save(
            handle,
            mod,
            fn,
            make_cast_wrapper(target, reporting_cast, handle, try_caching),
        )