From ed6787ca2fe950f633a925ccb0467eafd4ec0f43 Mon Sep 17 00:00:00 2001 From: EyeDeck Date: Sat, 17 Sep 2022 00:49:31 -0400 Subject: [PATCH] Add VRAM monitoring --- modules/memmon.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 5 +++ modules/ui.py | 14 ++++++++- style.css | 18 ++++++++++- 4 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 modules/memmon.py diff --git a/modules/memmon.py b/modules/memmon.py new file mode 100644 index 00000000..f2cac841 --- /dev/null +++ b/modules/memmon.py @@ -0,0 +1,77 @@ +import threading +import time +from collections import defaultdict + +import torch + + +class MemUsageMonitor(threading.Thread): + run_flag = None + device = None + disabled = False + opts = None + data = None + + def __init__(self, name, device, opts): + threading.Thread.__init__(self) + self.name = name + self.device = device + self.opts = opts + + self.daemon = True + self.run_flag = threading.Event() + self.data = defaultdict(int) + + def run(self): + if self.disabled: + return + + while True: + self.run_flag.wait() + + torch.cuda.reset_peak_memory_stats() + self.data.clear() + + if self.opts.memmon_poll_rate <= 0: + self.run_flag.clear() + continue + + self.data["min_free"] = torch.cuda.mem_get_info()[0] + + while self.run_flag.is_set(): + free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug? + self.data["min_free"] = min(self.data["min_free"], free) + + time.sleep(1 / self.opts.memmon_poll_rate) + + def dump_debug(self): + print(self, 'recorded data:') + for k, v in self.read().items(): + print(k, -(v // -(1024 ** 2))) + + print(self, 'raw torch memory stats:') + tm = torch.cuda.memory_stats(self.device) + for k, v in tm.items(): + if 'bytes' not in k: + continue + print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2))) + + print(torch.cuda.memory_summary()) + + def monitor(self): + self.run_flag.set() + + def read(self): + free, total = torch.cuda.mem_get_info() + self.data["total"] = total + + torch_stats = torch.cuda.memory_stats(self.device) + self.data["active_peak"] = torch_stats["active_bytes.all.peak"] + self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] + self.data["system_peak"] = total - self.data["min_free"] + + return self.data + + def stop(self): + self.run_flag.clear() + return self.read() diff --git a/modules/shared.py b/modules/shared.py index da56b6ae..4f877036 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -12,6 +12,7 @@ from modules.paths import script_path, sd_path from modules.devices import get_optimal_device import modules.styles import modules.interrogate +import modules.memmon sd_model_file = os.path.join(script_path, 'model.ckpt') if not os.path.exists(sd_model_file): @@ -138,6 +139,7 @@ class Options: "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."), + "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step":1}), "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), @@ -217,3 +219,6 @@ class TotalTQDM: total_tqdm = TotalTQDM() + +mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts) +mem_mon.start() diff --git a/modules/ui.py b/modules/ui.py index 738ac945..01b2ba85 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -119,6 +119,7 @@ def save_files(js_data, images, index): def wrap_gradio_call(func): def f(*args, **kwargs): + shared.mem_mon.monitor() t = time.perf_counter() try: @@ -135,8 +136,19 @@ def wrap_gradio_call(func): elapsed = time.perf_counter() - t + mem_stats = {k:-(v//-(1024*1024)) for k,v in shared.mem_mon.stop().items()} + active_peak = mem_stats['active_peak'] + reserved_peak = mem_stats['reserved_peak'] + sys_peak = '?' if opts.memmon_poll_rate <= 0 else mem_stats['system_peak'] + sys_total = mem_stats['total'] + sys_pct = '?' if opts.memmon_poll_rate <= 0 else round(sys_peak/sys_total * 100, 2) + vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data. " \ + "Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data. " \ + "Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)." + # last item is always HTML - res[-1] = res[-1] + f"

Time taken: {elapsed:.2f}s

" + res[-1] += f"

Time taken: {elapsed:.2f}s

" \ + f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" shared.state.interrupted = False diff --git a/style.css b/style.css index d41c098c..67ce8550 100644 --- a/style.css +++ b/style.css @@ -1,5 +1,21 @@ .output-html p {margin: 0 0.5em;} -.performance { font-size: 0.85em; color: #444; } + +.performance { + font-size: 0.85em; + color: #444; + display: flex; + justify-content: space-between; + white-space: nowrap; +} + +.performance .time { + margin-right: 0; +} + +.performance .vram { + margin-left: 0; + text-align: right; +} #generate{ min-height: 4.5em;