diff --git a/modules/call_queue.py b/modules/call_queue.py index 632afcdd..43f6ebe0 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -10,6 +10,11 @@ from modules import shared, progress queue_lock = threading.Lock() queue_lock_condition = threading.Condition(lock=queue_lock) +def wrap_session_call(func): + def f(request: gr.Request, *args, **kwargs): + return func(request, *args, **kwargs) + return f + def wrap_queued_call(func): def f(*args, **kwargs): with queue_lock: @@ -45,21 +50,18 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return res - return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, add_request=True) + return wrap_session_call(wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)) -def wrap_gradio_call(func, extra_outputs=None, add_stats=False, add_request=False): - def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs): +def wrap_gradio_call(func, extra_outputs=None, add_stats=False): + def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats if run_memmon: shared.mem_mon.monitor() t = time.perf_counter() try: - if add_request: - res = list(func(request, *args, **kwargs)) - else: - res = list(func(*args, **kwargs)) + res = list(func(*args, **kwargs)) except Exception as e: # When printing out our debug argument list, do not print out more than a MB of text max_debug_str_len = 131072 # (1024*1024)/8