multi users support

This commit is contained in:
siutin 2023-04-17 01:06:28 +08:00
parent 70ab21e67d
commit 984970068c
3 changed files with 59 additions and 28 deletions

View File

@ -4,6 +4,7 @@ import threading
import traceback import traceback
import time import time
import gradio as gr
from modules import shared, progress from modules import shared, progress
queue_lock = threading.Lock() queue_lock = threading.Lock()
@ -20,40 +21,44 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None): def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs): def f(request: gr.Request, *args, **kwargs):
user = request.username
# if the first argument is a string that says "task(...)", it is treated as a job id # if the first argument is a string that says "task(...)", it is treated as a job id
if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
id_task = args[0] id_task = args[0]
progress.add_task_to_queue(id_task) progress.add_task_to_queue(user, id_task)
else: else:
id_task = None id_task = None
with queue_lock: with queue_lock:
shared.state.begin() shared.state.begin()
progress.start_task(id_task) progress.start_task(user, id_task)
try: try:
res = func(*args, **kwargs) res = func(*args, **kwargs)
finally: finally:
progress.finish_task(id_task) progress.finish_task(user, id_task)
progress.set_last_task_result(id_task, res) progress.set_last_task_result(user, id_task, res)
shared.state.end() shared.state.end()
return res return res
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, add_request=True)
def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def wrap_gradio_call(func, extra_outputs=None, add_stats=False, add_request=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs): def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon: if run_memmon:
shared.mem_mon.monitor() shared.mem_mon.monitor()
t = time.perf_counter() t = time.perf_counter()
try: try:
if add_request:
res = list(func(request, *args, **kwargs))
else:
res = list(func(*args, **kwargs)) res = list(func(*args, **kwargs))
except Exception as e: except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text # When printing out our debug argument list, do not print out more than a MB of text

View File

@ -4,7 +4,9 @@ import time
import gradio as gr import gradio as gr
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List from typing import Optional
from fastapi import Depends, Security
from fastapi.security import APIKeyCookie
from modules import call_queue from modules import call_queue
from modules.shared import opts from modules.shared import opts
@ -12,57 +14,71 @@ from modules.shared import opts
import modules.shared as shared import modules.shared as shared
current_task_user = None
current_task = None current_task = None
pending_tasks = {} pending_tasks = {}
finished_tasks = [] finished_tasks = []
def start_task(id_task): def start_task(user, id_task):
global current_task global current_task
global current_task_user
current_task_user = user
current_task = id_task current_task = id_task
pending_tasks.pop(id_task, None) pending_tasks.pop((user, id_task), None)
def finish_task(id_task): def finish_task(user, id_task):
global current_task global current_task
global current_task_user
if current_task == id_task: if current_task == id_task:
current_task = None current_task = None
finished_tasks.append(id_task) if current_task_user == user:
current_task_user = None
finished_tasks.append((user, id_task))
if len(finished_tasks) > 16: if len(finished_tasks) > 16:
finished_tasks.pop(0) finished_tasks.pop(0)
def add_task_to_queue(id_job): def add_task_to_queue(user, id_job):
pending_tasks[id_job] = time.time() pending_tasks[(user, id_job)] = time.time()
last_task_id = None last_task_id = None
last_task_result = None last_task_result = None
last_task_user = None
def set_last_task_result(user, id_job, result):
def set_last_task_result(id_job, result):
global last_task_id global last_task_id
global last_task_result global last_task_result
global last_task_user
last_task_id = id_job last_task_id = id_job
last_task_result = result last_task_result = result
last_task_user = user
def restore_progress_call(): def restore_progress_call(request: gr.Request):
if current_task is None: if current_task is None:
# image, generation_info, html_info, html_log # image, generation_info, html_info, html_log
return tuple(list([None, None, None, None])) return tuple(list([None, None, None, None]))
else: else:
user = request.username
if current_task_user == user:
t_task = current_task t_task = current_task
with call_queue.queue_lock_condition: with call_queue.queue_lock_condition:
call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id) call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id)
return last_task_result return last_task_result
return tuple(list([None, None, None, None]))
class CurrentTaskResponse(BaseModel): class CurrentTaskResponse(BaseModel):
current_task: str = Field(default=None, title="Task ID", description="id of the current progress task") current_task: str = Field(default=None, title="Task ID", description="id of the current progress task")
@ -87,6 +103,19 @@ def setup_progress_api(app):
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
def setup_current_task_api(app): def setup_current_task_api(app):
def get_current_user(token: Optional[str] = Security(APIKeyCookie(name="access-token", auto_error=False))):
return None if token is None else app.tokens.get(token)
def current_task_api(current_user: str = Depends(get_current_user)):
if app.auth is None or current_task_user == current_user:
current_user_task = current_task
else:
current_user_task = None
return CurrentTaskResponse(current_task=current_user_task)
return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse) return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse)
def progressapi(req: ProgressRequest): def progressapi(req: ProgressRequest):
@ -128,6 +157,3 @@ def progressapi(req: ProgressRequest):
live_preview = None live_preview = None
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
def current_task_api():
return CurrentTaskResponse(current_task=current_task)

View File

@ -582,7 +582,7 @@ def create_ui():
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
restore_progress_button.click( restore_progress_button.click(
fn=lambda: restore_progress_call(), fn=restore_progress_call,
_js="() => restoreProgress('txt2img')", _js="() => restoreProgress('txt2img')",
inputs=[], inputs=[],
outputs=[ outputs=[
@ -914,7 +914,7 @@ def create_ui():
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
restore_progress_button.click( restore_progress_button.click(
fn=lambda: restore_progress_call(), fn=restore_progress_call,
_js="() => restoreProgress('img2img')", _js="() => restoreProgress('img2img')",
inputs=[], inputs=[],
outputs=[ outputs=[