File size: 2,755 Bytes
ea58bad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import threading
import time
from collections import defaultdict
import torch
class MemUsageMonitor(threading.Thread):
    """Daemon thread that samples free CUDA memory while monitoring is enabled.

    Call ``monitor()`` to start a measurement and ``stop()`` to end it and
    fetch the collected numbers. While a measurement is running the thread
    polls ``torch.cuda.mem_get_info`` and remembers the lowest amount of free
    memory it saw (``data["min_free"]``).

    If the CUDA queries made in ``__init__`` fail, the monitor marks itself as
    ``disabled``: ``run()`` returns immediately and ``read()`` returns whatever
    is in ``data`` without touching CUDA.
    """

    run_flag = None
    device = None
    disabled = False
    opts = None
    data = None

    def __init__(self, name, device, opts):
        """Create the monitor (the thread is not started here).

        Args:
            name: Thread name.
            device: ``torch.device`` whose memory is monitored.
            opts: Object exposing ``memmon_poll_rate`` (samples per second;
                a value <= 0 turns sampling off).
        """
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        # Probe both APIs used later; any failure disables the monitor
        # instead of crashing the caller.
        try:
            self.cuda_mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def cuda_mem_get_info(self):
        """Return ``(free, total)`` from ``torch.cuda.mem_get_info``.

        Uses ``self.device.index`` when set, otherwise the calling thread's
        current CUDA device.
        """
        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
        return torch.cuda.mem_get_info(index)

    def run(self):
        """Thread body: wait for ``monitor()``, then poll until ``stop()``."""
        if self.disabled:
            return

        while True:
            self.run_flag.wait()

            # Reset the peaks of the monitored device explicitly: the
            # "current device" is per-thread and may differ from self.device.
            torch.cuda.reset_peak_memory_stats(self.device)
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                self.run_flag.clear()
                continue

            self.data["min_free"] = self.cuda_mem_get_info()[0]

            while self.run_flag.is_set():
                free, _ = self.cuda_mem_get_info()
                self.data["min_free"] = min(self.data["min_free"], free)

                # The option is re-read on every iteration; if it dropped to
                # <= 0 meanwhile, end this run instead of dividing by zero
                # (which would kill the thread for good).
                poll_rate = self.opts.memmon_poll_rate
                if poll_rate <= 0:
                    self.run_flag.clear()
                    break

                time.sleep(1 / poll_rate)

    def dump_debug(self):
        """Print recorded data and torch's byte statistics, rounded up to MiB."""
        print(self, 'recorded data:')
        for k, v in self.read().items():
            # -(v // -x) is ceiling division.
            print(k, -(v // -(1024 ** 2)))

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        # Summarize the monitored device, consistent with memory_stats above.
        print(torch.cuda.memory_summary(self.device))

    def monitor(self):
        """Start (or continue) a measurement by setting the run flag."""
        self.run_flag.set()

    def read(self):
        """Refresh and return ``self.data``.

        When the monitor is not disabled, the following keys are updated
        (all values in bytes): ``free``, ``total``, ``active``,
        ``active_peak``, ``reserved``, ``reserved_peak`` and ``system_peak``
        (``total`` minus the lowest free memory observed; ``min_free`` reads
        as 0 if no sample was recorded).

        Returns:
            The monitor's ``data`` mapping itself, not a copy.
        """
        if not self.disabled:
            free, total = self.cuda_mem_get_info()
            self.data["free"] = free
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            # "active.all.current" is a block count; the byte figure that
            # matches the other entries is "active_bytes.all.current".
            self.data["active"] = torch_stats["active_bytes.all.current"]
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        """End the current measurement and return the result of ``read()``."""
        self.run_flag.clear()
        return self.read()
|