import functools
import gc
import os

import psutil
import torch


def print_memory_usage():
    # Report the resident set size (RSS) of the current process in megabytes.
    process = psutil.Process(os.getpid())
    print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:.2f} MB")


def clear_cuda_and_gc():
    # Report memory usage, free Python garbage and cached CUDA memory,
    # then report memory usage again.
    print_memory_usage()
    print("Clearing cuda and gc")
    clear_gc()
    clear_cuda()
    print_memory_usage()


def clear_cuda():
    # Release unused GPU memory held by PyTorch's caching allocator.
    with torch.no_grad():
        torch.cuda.empty_cache()


def clear_gc():
    # Force a full Python garbage-collection pass.
    gc.collect()


def auto_clear_cuda_and_gc(controlnet):
    # Decorator factory: wrap a function so that if it raises, the given
    # controlnet is cleaned up and CUDA/GC memory is freed before the
    # exception propagates.
    def auto_clear_cuda_and_gc_wrapper(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:
                controlnet.cleanup()
                clear_cuda_and_gc()
                raise

        return wrapper

    return auto_clear_cuda_and_gc_wrapper
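

# Illustrative usage (a sketch only; `load_controlnet_pipeline` and the
# pipeline's `cleanup()` method are assumed placeholders, not part of this
# module):
#
#     pipeline = load_controlnet_pipeline()
#
#     @auto_clear_cuda_and_gc(pipeline)
#     def generate(prompt):
#         return pipeline(prompt)
#
# If `generate` raises, the wrapper calls `pipeline.cleanup()`, clears the
# CUDA cache and the Python garbage collector, and re-raises the exception.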