multi users support

This commit is contained in:
siutin 2023-04-17 01:06:28 +08:00
parent 70ab21e67d
commit 984970068c
3 changed files with 59 additions and 28 deletions

View File

@ -4,6 +4,7 @@ import threading
import traceback
import time
import gradio as gr
from modules import shared, progress
queue_lock = threading.Lock()
@ -20,40 +21,44 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
def f(request: gr.Request, *args, **kwargs):
user = request.username
# if the first argument is a string that says "task(...)", it is treated as a job id
if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
id_task = args[0]
progress.add_task_to_queue(id_task)
progress.add_task_to_queue(user, id_task)
else:
id_task = None
with queue_lock:
shared.state.begin()
progress.start_task(id_task)
progress.start_task(user, id_task)
try:
res = func(*args, **kwargs)
finally:
progress.finish_task(id_task)
progress.set_last_task_result(id_task, res)
progress.finish_task(user, id_task)
progress.set_last_task_result(user, id_task, res)
shared.state.end()
return res
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, add_request=True)
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False, add_request=False):
def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
try:
if add_request:
res = list(func(request, *args, **kwargs))
else:
res = list(func(*args, **kwargs))
except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text

View File

@ -4,7 +4,9 @@ import time
import gradio as gr
from pydantic import BaseModel, Field
from typing import List
from typing import Optional
from fastapi import Depends, Security
from fastapi.security import APIKeyCookie
from modules import call_queue
from modules.shared import opts
@ -12,57 +14,71 @@ from modules.shared import opts
import modules.shared as shared
current_task_user = None
current_task = None
pending_tasks = {}
finished_tasks = []
def start_task(id_task):
def start_task(user, id_task):
global current_task
global current_task_user
current_task_user = user
current_task = id_task
pending_tasks.pop(id_task, None)
pending_tasks.pop((user, id_task), None)
def finish_task(id_task):
def finish_task(user, id_task):
global current_task
global current_task_user
if current_task == id_task:
current_task = None
finished_tasks.append(id_task)
if current_task_user == user:
current_task_user = None
finished_tasks.append((user, id_task))
if len(finished_tasks) > 16:
finished_tasks.pop(0)
def add_task_to_queue(id_job):
pending_tasks[id_job] = time.time()
def add_task_to_queue(user, id_job):
pending_tasks[(user, id_job)] = time.time()
last_task_id = None
last_task_result = None
last_task_user = None
def set_last_task_result(user, id_job, result):
def set_last_task_result(id_job, result):
global last_task_id
global last_task_result
global last_task_user
last_task_id = id_job
last_task_result = result
last_task_user = user
def restore_progress_call():
def restore_progress_call(request: gr.Request):
if current_task is None:
# image, generation_info, html_info, html_log
return tuple(list([None, None, None, None]))
else:
user = request.username
if current_task_user == user:
t_task = current_task
with call_queue.queue_lock_condition:
call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id)
return last_task_result
return tuple(list([None, None, None, None]))
class CurrentTaskResponse(BaseModel):
current_task: str = Field(default=None, title="Task ID", description="id of the current progress task")
@ -87,6 +103,19 @@ def setup_progress_api(app):
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
def setup_current_task_api(app):
def get_current_user(token: Optional[str] = Security(APIKeyCookie(name="access-token", auto_error=False))):
return None if token is None else app.tokens.get(token)
def current_task_api(current_user: str = Depends(get_current_user)):
if app.auth is None or current_task_user == current_user:
current_user_task = current_task
else:
current_user_task = None
return CurrentTaskResponse(current_task=current_user_task)
return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse)
def progressapi(req: ProgressRequest):
@ -128,6 +157,3 @@ def progressapi(req: ProgressRequest):
live_preview = None
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
def current_task_api():
return CurrentTaskResponse(current_task=current_task)

View File

@ -582,7 +582,7 @@ def create_ui():
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
restore_progress_button.click(
fn=lambda: restore_progress_call(),
fn=restore_progress_call,
_js="() => restoreProgress('txt2img')",
inputs=[],
outputs=[
@ -914,7 +914,7 @@ def create_ui():
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
restore_progress_button.click(
fn=lambda: restore_progress_call(),
fn=restore_progress_call,
_js="() => restoreProgress('img2img')",
inputs=[],
outputs=[