import functools
import gc
import os

import psutil
import torch
def print_memory_usage():
    """Print the resident set size (RSS) of the current process in MB.

    The RSS reported by psutil is in bytes and is divided by 1024**2 before
    printing, shown with two decimal places.
    """
    process = psutil.Process(os.getpid())
    rss_mb = process.memory_info().rss / 1024 ** 2
    # `.2f` = two decimal places; the previous `2f` meant "min width 2".
    print(f"Memory usage: {rss_mb:.2f} MB")
def clear_cuda_and_gc():
    """Run garbage collection, then release cached CUDA memory.

    Process memory usage is printed before and after, so the effect of the
    cleanup is visible in the output.
    """
    print_memory_usage()
    print("Clearing cuda and gc")
    # Order matters: collect Python garbage first so tensors that become
    # unreachable are freed before the CUDA cache is emptied.
    for cleanup_step in (clear_gc, clear_cuda):
        cleanup_step()
    print_memory_usage()
def clear_cuda():
    """Return unoccupied cached memory held by PyTorch's CUDA caching allocator.

    `torch.cuda.empty_cache()` performs no autograd operations, so no
    `torch.no_grad()` context is needed around it.
    """
    torch.cuda.empty_cache()
def clear_gc():
    """Force a garbage-collection pass; the collected-object count is discarded."""
    gc.collect()
    return None
def auto_clear_cuda_and_gc(controlnet):
    """Build a decorator that frees resources when the decorated call fails.

    Args:
        controlnet: Any object exposing a ``cleanup()`` method; it is called
            when the wrapped function raises an ``Exception``.

    Returns:
        A decorator. The decorated function behaves like the original on
        success. On an ``Exception`` it calls ``controlnet.cleanup()``, then
        ``clear_cuda_and_gc()``, then re-raises the original exception.
    """
    def auto_clear_cuda_and_gc_wrapper(func):
        # Preserve the wrapped function's name, docstring and metadata.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:
                try:
                    controlnet.cleanup()
                finally:
                    # Release memory even if cleanup() itself raises.
                    clear_cuda_and_gc()
                # Bare raise re-raises the original exception with its
                # traceback intact.
                raise
        return wrapper
    return auto_clear_cuda_and_gc_wrapper