Add VRAM monitoring
This commit is contained in:
parent
1fc1c537c7
commit
ed6787ca2f
77
modules/memmon.py
Normal file
77
modules/memmon.py
Normal file
@ -0,0 +1,77 @@
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MemUsageMonitor(threading.Thread):
|
||||
run_flag = None
|
||||
device = None
|
||||
disabled = False
|
||||
opts = None
|
||||
data = None
|
||||
|
||||
def __init__(self, name, device, opts):
|
||||
threading.Thread.__init__(self)
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.opts = opts
|
||||
|
||||
self.daemon = True
|
||||
self.run_flag = threading.Event()
|
||||
self.data = defaultdict(int)
|
||||
|
||||
def run(self):
|
||||
if self.disabled:
|
||||
return
|
||||
|
||||
while True:
|
||||
self.run_flag.wait()
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.data.clear()
|
||||
|
||||
if self.opts.memmon_poll_rate <= 0:
|
||||
self.run_flag.clear()
|
||||
continue
|
||||
|
||||
self.data["min_free"] = torch.cuda.mem_get_info()[0]
|
||||
|
||||
while self.run_flag.is_set():
|
||||
free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
|
||||
self.data["min_free"] = min(self.data["min_free"], free)
|
||||
|
||||
time.sleep(1 / self.opts.memmon_poll_rate)
|
||||
|
||||
def dump_debug(self):
|
||||
print(self, 'recorded data:')
|
||||
for k, v in self.read().items():
|
||||
print(k, -(v // -(1024 ** 2)))
|
||||
|
||||
print(self, 'raw torch memory stats:')
|
||||
tm = torch.cuda.memory_stats(self.device)
|
||||
for k, v in tm.items():
|
||||
if 'bytes' not in k:
|
||||
continue
|
||||
print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
|
||||
|
||||
print(torch.cuda.memory_summary())
|
||||
|
||||
def monitor(self):
|
||||
self.run_flag.set()
|
||||
|
||||
def read(self):
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
self.data["total"] = total
|
||||
|
||||
torch_stats = torch.cuda.memory_stats(self.device)
|
||||
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
|
||||
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
|
||||
self.data["system_peak"] = total - self.data["min_free"]
|
||||
|
||||
return self.data
|
||||
|
||||
def stop(self):
|
||||
self.run_flag.clear()
|
||||
return self.read()
|
@ -12,6 +12,7 @@ from modules.paths import script_path, sd_path
|
||||
from modules.devices import get_optimal_device
|
||||
import modules.styles
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
if not os.path.exists(sd_model_file):
|
||||
@ -138,6 +139,7 @@ class Options:
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step":1}),
|
||||
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||
@ -217,3 +219,6 @@ class TotalTQDM:
|
||||
|
||||
|
||||
total_tqdm = TotalTQDM()
|
||||
|
||||
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
||||
mem_mon.start()
|
||||
|
@ -119,6 +119,7 @@ def save_files(js_data, images, index):
|
||||
|
||||
def wrap_gradio_call(func):
|
||||
def f(*args, **kwargs):
|
||||
shared.mem_mon.monitor()
|
||||
t = time.perf_counter()
|
||||
|
||||
try:
|
||||
@ -135,8 +136,19 @@ def wrap_gradio_call(func):
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
|
||||
mem_stats = {k:-(v//-(1024*1024)) for k,v in shared.mem_mon.stop().items()}
|
||||
active_peak = mem_stats['active_peak']
|
||||
reserved_peak = mem_stats['reserved_peak']
|
||||
sys_peak = '?' if opts.memmon_poll_rate <= 0 else mem_stats['system_peak']
|
||||
sys_total = mem_stats['total']
|
||||
sys_pct = '?' if opts.memmon_poll_rate <= 0 else round(sys_peak/sys_total * 100, 2)
|
||||
vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.
" \
|
||||
"Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.
" \
|
||||
"Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
|
||||
|
||||
# last item is always HTML
|
||||
res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>" \
|
||||
f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p></div>"
|
||||
|
||||
shared.state.interrupted = False
|
||||
|
||||
|
18
style.css
18
style.css
@ -1,5 +1,21 @@
|
||||
.output-html p {margin: 0 0.5em;}
|
||||
.performance { font-size: 0.85em; color: #444; }
|
||||
|
||||
.performance {
|
||||
font-size: 0.85em;
|
||||
color: #444;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.performance .time {
|
||||
margin-right: 0;
|
||||
}
|
||||
|
||||
.performance .vram {
|
||||
margin-left: 0;
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
#generate{
|
||||
min-height: 4.5em;
|
||||
|
Loading…
Reference in New Issue
Block a user