add an internal API for obtaining current task id

This commit is contained in:
siutin 2023-02-03 03:13:03 +08:00
parent 22bcc7be42
commit dbca512154
2 changed files with 9 additions and 0 deletions

View File

@ -4,6 +4,7 @@ import time
import gradio as gr import gradio as gr
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List
from modules.shared import opts from modules.shared import opts
@ -37,6 +38,9 @@ def add_task_to_queue(id_job):
pending_tasks[id_job] = time.time() pending_tasks[id_job] = time.time()
class CurrentTaskResponse(BaseModel):
current_task: str = Field(default=None, title="Task ID", description="id of the current progress task")
class ProgressRequest(BaseModel): class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image") id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
@ -56,6 +60,8 @@ class ProgressResponse(BaseModel):
def setup_progress_api(app): def setup_progress_api(app):
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
def setup_current_task_api(app):
return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse)
def progressapi(req: ProgressRequest): def progressapi(req: ProgressRequest):
active = req.id_task == current_task active = req.id_task == current_task
@ -97,3 +103,5 @@ def progressapi(req: ProgressRequest):
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
def current_task_api():
return CurrentTaskResponse(current_task=current_task)

View File

@ -279,6 +279,7 @@ def webui():
setup_middleware(app) setup_middleware(app)
modules.progress.setup_progress_api(app) modules.progress.setup_progress_api(app)
modules.progress.setup_current_task_api(app)
if launch_api: if launch_api:
create_api(app) create_api(app)