replace with #wrap_session_call

This commit is contained in:
siutin 2023-04-17 11:50:08 +08:00
parent 984970068c
commit 3e5b3c79e4

View File

@ -10,6 +10,11 @@ from modules import shared, progress
queue_lock = threading.Lock() queue_lock = threading.Lock()
queue_lock_condition = threading.Condition(lock=queue_lock) queue_lock_condition = threading.Condition(lock=queue_lock)
def wrap_session_call(func):
def f(request: gr.Request, *args, **kwargs):
return func(request, *args, **kwargs)
return f
def wrap_queued_call(func): def wrap_queued_call(func):
def f(*args, **kwargs): def f(*args, **kwargs):
with queue_lock: with queue_lock:
@ -45,20 +50,17 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
return res return res
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, add_request=True) return wrap_session_call(wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True))
def wrap_gradio_call(func, extra_outputs=None, add_stats=False, add_request=False): def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs): def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon: if run_memmon:
shared.mem_mon.monitor() shared.mem_mon.monitor()
t = time.perf_counter() t = time.perf_counter()
try: try:
if add_request:
res = list(func(request, *args, **kwargs))
else:
res = list(func(*args, **kwargs)) res = list(func(*args, **kwargs))
except Exception as e: except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text # When printing out our debug argument list, do not print out more than a MB of text