import gc
import inspect
import threading
from functools import wraps

import torch

from ia_check_versions import ia_check_versions

# Binary semaphore used to serialize model access across threads; shared by
# async_post_clear_cache() so cache clearing waits until the model is free.
model_access_sem = threading.Semaphore(1)


def torch_gc():
    """Free cached accelerator memory.

    Empties the CUDA cache (and runs IPC collection) when CUDA is
    available, and empties the MPS cache when the installed torch build
    exposes it (availability reported by the project's version checker).
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    if ia_check_versions.torch_mps_is_available:
        # Older torch builds may lack torch.mps or its empty_cache().
        mps_module = getattr(torch, "mps", None)
        if mps_module is not None and hasattr(mps_module, "empty_cache"):
            mps_module.empty_cache()


def clear_cache():
    """Run Python garbage collection, then release accelerator caches."""
    gc.collect()
    torch_gc()


def post_clear_cache(sem):
    """Clear caches while holding *sem*.

    Used as a thread target so the cleanup waits until the semaphore
    (i.e. model access) is free, then collects garbage and empties the
    accelerator caches before releasing it again.
    """
    sem.acquire()
    try:
        gc.collect()
        torch_gc()
    finally:
        sem.release()


def async_post_clear_cache():
    """Start cache clearing on a background thread.

    The worker runs post_clear_cache() guarded by the module-level
    model-access semaphore; this call returns immediately without
    joining the thread.
    """
    worker = threading.Thread(
        target=post_clear_cache,
        args=(model_access_sem,),
    )
    worker.start()


def clear_cache_decorator(func):
    """Decorator that clears caches before and after invoking *func*.

    Handles both plain callables and generator functions: generator
    functions get a generator wrapper so laziness is preserved and the
    post-call cleanup runs only once iteration finishes.

    Fix over the previous version: the trailing clear_cache() is now in a
    ``finally`` block, so caches are released even when *func* raises or
    when a wrapped generator is closed early / raises mid-iteration
    (previously the cleanup was silently skipped on those paths).
    """
    @wraps(func)
    def yield_wrapper(*args, **kwargs):
        clear_cache()
        try:
            yield from func(*args, **kwargs)
        finally:
            # Runs on exhaustion, on exception, and on early close
            # (GeneratorExit propagates through the finally block).
            clear_cache()

    @wraps(func)
    def wrapper(*args, **kwargs):
        clear_cache()
        try:
            return func(*args, **kwargs)
        finally:
            clear_cache()

    if inspect.isgeneratorfunction(func):
        return yield_wrapper
    return wrapper