Add VRAM monitoring
This commit is contained in:
parent
1fc1c537c7
commit
ed6787ca2f
77
modules/memmon.py
Normal file
77
modules/memmon.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class MemUsageMonitor(threading.Thread):
    """Daemon thread that records peak GPU memory use while a job is running.

    The thread sleeps until ``monitor()`` sets ``run_flag``.  It then resets
    torch's peak counters, clears ``data`` and repeatedly samples
    ``torch.cuda.mem_get_info()`` to remember the lowest amount of free device
    memory seen, until ``stop()`` clears the flag again.

    ``data`` (a ``defaultdict(int)``) holds byte counts under these keys:

    * ``min_free``      - smallest free-memory reading seen for the job
    * ``total``         - total device memory reported by torch
    * ``active_peak``   - torch ``active_bytes.all.peak``
    * ``reserved_peak`` - torch ``reserved_bytes.all.peak``
    * ``system_peak``   - ``total - min_free``
    """

    # Class-level defaults; the real values are assigned per instance in
    # __init__.  ``disabled`` is never set by this class itself: setting it to
    # True before start() makes run() return immediately.
    run_flag = None
    device = None
    disabled = False
    opts = None
    data = None

    def __init__(self, name, device, opts):
        """Create the monitor thread (not started yet).

        name   -- thread name.
        device -- device handed to ``torch.cuda.memory_stats``.
        opts   -- options object; ``opts.memmon_poll_rate`` (polls per second,
                  ``<= 0`` disables polling) is re-read on every poll.
        """
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

    def _poll_interval(self):
        """Return seconds to sleep between polls, or None if polling is off.

        Reads ``opts.memmon_poll_rate`` on every call so that a rate changed
        to 0 while a job is running cannot cause a division by zero.
        """
        rate = self.opts.memmon_poll_rate
        if rate <= 0:
            return None
        return 1 / rate

    def run(self):
        """Thread body: wait for ``monitor()``, then poll until ``stop()``."""
        if self.disabled:
            return

        while True:
            self.run_flag.wait()

            # A new job starts: forget the previous job's peaks.
            torch.cuda.reset_peak_memory_stats()
            self.data.clear()

            if self._poll_interval() is None:
                # Polling is switched off; go back to waiting for the next job.
                self.run_flag.clear()
                continue

            self.data["min_free"] = torch.cuda.mem_get_info()[0]

            while self.run_flag.is_set():
                free, _ = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
                self.data["min_free"] = min(self.data["min_free"], free)

                interval = self._poll_interval()
                if interval is None:
                    # The poll rate was set to 0 mid-job.  Stop polling instead
                    # of letting ZeroDivisionError kill this thread for good.
                    self.run_flag.clear()
                    break

                time.sleep(interval)

    def dump_debug(self):
        """Print the recorded data and torch's raw byte statistics in MiB."""
        print(self, 'recorded data:')
        for k, v in self.read().items():
            # -(v // -x) is ceiling division: round up to whole MiB.
            print(k, -(v // -(1024 ** 2)))

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        """Signal the thread to start a new measurement."""
        self.run_flag.set()

    def read(self):
        """Refresh ``data`` with current torch statistics and return it.

        The free-memory reading taken here is folded into ``min_free``.  When
        the polling loop recorded nothing (polling disabled, or ``stop()``
        called before the first sample) ``system_peak`` therefore reflects the
        current usage instead of falsely reporting the whole device as used.

        Returns the live ``data`` mapping itself, not a copy.
        """
        free, total = torch.cuda.mem_get_info()
        self.data["total"] = total

        torch_stats = torch.cuda.memory_stats(self.device)
        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]

        # .get() avoids the defaultdict's implicit 0, which would make
        # system_peak equal to total.
        self.data["min_free"] = min(self.data.get("min_free", free), free)
        self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        """End the current measurement and return the collected data."""
        self.run_flag.clear()
        return self.read()
|
@ -12,6 +12,7 @@ from modules.paths import script_path, sd_path
|
|||||||
from modules.devices import get_optimal_device
|
from modules.devices import get_optimal_device
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.interrogate
|
import modules.interrogate
|
||||||
|
import modules.memmon
|
||||||
|
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
if not os.path.exists(sd_model_file):
|
if not os.path.exists(sd_model_file):
|
||||||
@ -138,6 +139,7 @@ class Options:
|
|||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
|
||||||
|
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step":1}),
|
||||||
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
||||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||||
@ -217,3 +219,6 @@ class TotalTQDM:
|
|||||||
|
|
||||||
|
|
||||||
total_tqdm = TotalTQDM()
|
total_tqdm = TotalTQDM()
|
||||||
|
|
||||||
|
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
||||||
|
mem_mon.start()
|
||||||
|
@ -119,6 +119,7 @@ def save_files(js_data, images, index):
|
|||||||
|
|
||||||
def wrap_gradio_call(func):
|
def wrap_gradio_call(func):
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
|
shared.mem_mon.monitor()
|
||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -135,8 +136,19 @@ def wrap_gradio_call(func):
|
|||||||
|
|
||||||
elapsed = time.perf_counter() - t
|
elapsed = time.perf_counter() - t
|
||||||
|
|
||||||
|
mem_stats = {k:-(v//-(1024*1024)) for k,v in shared.mem_mon.stop().items()}
|
||||||
|
active_peak = mem_stats['active_peak']
|
||||||
|
reserved_peak = mem_stats['reserved_peak']
|
||||||
|
sys_peak = '?' if opts.memmon_poll_rate <= 0 else mem_stats['system_peak']
|
||||||
|
sys_total = mem_stats['total']
|
||||||
|
sys_pct = '?' if opts.memmon_poll_rate <= 0 else round(sys_peak/sys_total * 100, 2)
|
||||||
|
vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.
" \
|
||||||
|
"Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.
" \
|
||||||
|
"Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
|
||||||
|
|
||||||
# last item is always HTML
|
# last item is always HTML
|
||||||
res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"
|
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>" \
|
||||||
|
f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p></div>"
|
||||||
|
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
|
|
||||||
|
18
style.css
18
style.css
@ -1,5 +1,21 @@
|
|||||||
.output-html p {margin: 0 0.5em;}
|
.output-html p {margin: 0 0.5em;}
|
||||||
.performance { font-size: 0.85em; color: #444; }
|
|
||||||
|
.performance {
|
||||||
|
font-size: 0.85em;
|
||||||
|
color: #444;
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.performance .time {
|
||||||
|
margin-right: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.performance .vram {
|
||||||
|
margin-left: 0;
|
||||||
|
text-align: right;
|
||||||
|
}
|
||||||
|
|
||||||
#generate{
|
#generate{
|
||||||
min-height: 4.5em;
|
min-height: 4.5em;
|
||||||
|
Loading…
Reference in New Issue
Block a user