File size: 2,257 Bytes
345ee20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from . import utils

import functools

import torch

def make_cast_wrapper(orig_fn, cast_fn, handle,
                      try_caching=False):
    """Build a wrapper around ``orig_fn`` that casts its arguments with ``cast_fn``.

    The returned wrapper behaves as follows:

    * If ``handle.is_active()`` is false, ``orig_fn`` is called untouched.
    * Otherwise the type string of the first floating-point tensor argument
      (positional arguments first, then keyword values) is recorded, the
      arguments are cast via ``utils.casted_args`` and ``orig_fn`` is called.
    * The result is passed through ``cast_output`` so floating-point outputs
      are converted back to the recorded input type.
    * If no floating-point tensor was passed at all there is no type to
      restore, so the result of ``orig_fn`` is returned unchanged.

    Args:
        orig_fn: The callable being wrapped.
        cast_fn: Callable handed to ``utils.cached_cast`` / ``utils.casted_args``
            to perform the actual cast.
        handle: Object exposing ``is_active()``, ``has_cache`` and ``cache``.
        try_caching: When true and ``handle.has_cache`` is truthy, arguments
            for which ``utils.should_cache`` holds are cast through
            ``utils.cached_cast`` using ``handle.cache``.

    Returns:
        The wrapper function, carrying ``orig_fn``'s metadata via
        ``functools.wraps``.
    """
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        if not handle.is_active():
            return orig_fn(*args, **kwargs)

        # Type of the first floating-point tensor argument, captured before
        # any casting. None means no such argument exists; previously this
        # case raised IndexError on an empty list.
        input_type = next(
            (v.data.type() for v in list(args) + list(kwargs.values())
             if utils.is_fp_tensor(v)),
            None,
        )

        if try_caching and handle.has_cache:
            args = [
                utils.cached_cast(cast_fn, arg, handle.cache)
                if utils.should_cache(arg) else arg
                for arg in args
            ]
            # Only existing keys are reassigned, so iterating the dict while
            # updating its values is safe.
            for key in kwargs:
                if utils.should_cache(kwargs[key]):
                    kwargs[key] = utils.cached_cast(
                        cast_fn, kwargs[key], handle.cache)

        # NOTE(review): kwargs is forwarded as-is after this call, so
        # utils.casted_args presumably casts keyword values in place -
        # confirm against utils.
        new_args = utils.casted_args(cast_fn,
                                     args,
                                     kwargs)
        output = orig_fn(*new_args, **kwargs)

        if input_type is None:
            # Nothing to restore the output type to.
            return output
        return cast_output(output, input_type, verbose=False)

    return wrapper

def cast_output(output, input_type, verbose=False):
    """Convert floating-point tensor outputs to ``input_type``.

    Dictionaries are processed recursively and updated in place: every value
    is replaced by its converted counterpart and the same dict object is
    returned. Any value that ``utils.is_fp_tensor`` does not accept, or whose
    ``type()`` already equals ``input_type``, is returned unchanged.

    Args:
        output: A tensor, a (possibly nested) dict of outputs, or any other
            object.
        input_type: Target type string as produced by ``tensor.type()``.
            ``None`` means there is no target type, and ``output`` is
            returned untouched.
        verbose: When true, print a message for every tensor that is
            converted, including tensors nested inside dicts.

    Returns:
        ``output`` itself, or a converted tensor.
    """
    if input_type is None:
        # Calling ``tensor.type(None)`` would yield the type *string* rather
        # than a tensor, so bail out before touching anything.
        return output

    if isinstance(output, dict):
        for k in output:
            # Forward ``verbose`` so nested conversions are reported too.
            output[k] = cast_output(output[k], input_type, verbose=verbose)
        return output

    if utils.is_fp_tensor(output) and output.type() != input_type:
        if verbose:
            print('ori output type: {}, input type: {}'.format(output.type(), input_type))
        return output.type(input_type)
    return output

def cached_cast(mod, fn, cast_fn, handle,
                try_caching=False, verbose=False):
    """Replace ``mod``'s function ``fn`` with a casting wrapper.

    Does nothing when ``utils.has_func(mod, fn)`` is false. Otherwise the
    current function is fetched with ``utils.get_func``, ``cast_fn`` is
    passed through ``utils.verbosify`` together with ``fn`` and ``verbose``,
    a wrapper is built by ``make_cast_wrapper``, and the wrapper is installed
    through ``utils.set_func_save``.

    Args:
        mod: Object (module or class) that owns the function.
        fn: Name of the function to wrap.
        cast_fn: Cast callable forwarded to the wrapper.
        handle: Handle forwarded to the wrapper and to
            ``utils.set_func_save``.
        try_caching: Forwarded to ``make_cast_wrapper``.
        verbose: Forwarded to ``utils.verbosify``.

    Returns:
        None.
    """
    if utils.has_func(mod, fn):
        target = utils.get_func(mod, fn)
        reporting_cast = utils.verbosify(cast_fn, fn, verbose)
        utils.set_func_save(
            handle, mod, fn,
            make_cast_wrapper(target, reporting_cast, handle, try_caching),
        )