Merge pull request #7595 from siutin/feature/restore-progress

restore the progress from session lost / tab reload
This commit is contained in:
AUTOMATIC1111 2023-04-29 22:13:48 +03:00 committed by GitHub
commit 80987c36f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 141 additions and 18 deletions

View File

@ -22,7 +22,7 @@ titles = {
"\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4cb}": "Apply selected styles to current prompt",
"\u{1f4d2}": "Paste available values into the field", "\u{1f4d2}": "Paste available values into the field",
"\u{1f3b4}": "Show/hide extra networks", "\u{1f3b4}": "Show/hide extra networks",
"\u{1F300}": "Restore progress",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",

View File

@ -362,6 +362,32 @@ function selectCheckpoint(name){
gradioApp().getElementById('change_checkpoint').click() gradioApp().getElementById('change_checkpoint').click()
} }
function restoreProgress (task_tag) {
if (task_tag) {
let successHandler = ({ current_task }) => {
if (current_task) {
showSubmitButtons(task_tag, false)
requestProgress(current_task, gradioApp().getElementById(`${task_tag}_gallery_container`), gradioApp().getElementById(`${task_tag}_gallery`), function(){
showSubmitButtons(task_tag, true)
})
}
}
let errorHandler = e => window.alert(`invalid internal api respsonse. message: ${e}`)
fetch("./internal/current_task")
.then(res => res.json())
.then(successHandler)
.catch(errorHandler)
}
var res = create_submit_args(arguments)
res[0] = 0
return res
}
function currentImg2imgSourceResolution(_, _, scaleBy){ function currentImg2imgSourceResolution(_, _, scaleBy){
var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img') var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img')
return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy] return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy]
@ -377,3 +403,4 @@ function updateImg2imgResizeToTextAfterChangingImage(){
return [] return []
} }

View File

@ -4,10 +4,16 @@ import threading
import traceback import traceback
import time import time
import gradio as gr
from modules import shared, progress from modules import shared, progress
queue_lock = threading.Lock() queue_lock = threading.Lock()
queue_lock_condition = threading.Condition(lock=queue_lock)
def wrap_session_call(func):
def f(request: gr.Request, *args, **kwargs):
return func(request, *args, **kwargs)
return f
def wrap_queued_call(func): def wrap_queued_call(func):
def f(*args, **kwargs): def f(*args, **kwargs):
@ -20,29 +26,31 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None): def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs): def f(request: gr.Request, *args, **kwargs):
user = request.username
# if the first argument is a string that says "task(...)", it is treated as a job id # if the first argument is a string that says "task(...)", it is treated as a job id
if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
id_task = args[0] id_task = args[0]
progress.add_task_to_queue(id_task) progress.add_task_to_queue(user, id_task)
else: else:
id_task = None id_task = None
with queue_lock: with queue_lock:
shared.state.begin() shared.state.begin()
progress.start_task(id_task) progress.start_task(user, id_task)
try: try:
res = func(*args, **kwargs) res = func(*args, **kwargs)
finally: finally:
progress.finish_task(id_task) progress.finish_task(user, id_task)
progress.set_last_task_result(user, id_task, res)
shared.state.end() shared.state.end()
return res return res
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) return wrap_session_call(wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True))
def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def wrap_gradio_call(func, extra_outputs=None, add_stats=False):

View File

@ -4,38 +4,84 @@ import time
import gradio as gr import gradio as gr
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional
from fastapi import Depends, Security
from fastapi.security import APIKeyCookie
from modules import call_queue
from modules.shared import opts from modules.shared import opts
import modules.shared as shared import modules.shared as shared
current_task_user = None
current_task = None current_task = None
pending_tasks = {} pending_tasks = {}
finished_tasks = [] finished_tasks = []
def start_task(id_task): def start_task(user, id_task):
global current_task global current_task
global current_task_user
current_task_user = user
current_task = id_task current_task = id_task
pending_tasks.pop(id_task, None) pending_tasks.pop((user, id_task), None)
def finish_task(id_task): def finish_task(user, id_task):
global current_task global current_task
global current_task_user
if current_task == id_task: if current_task == id_task:
current_task = None current_task = None
finished_tasks.append(id_task) if current_task_user == user:
current_task_user = None
finished_tasks.append((user, id_task))
if len(finished_tasks) > 16: if len(finished_tasks) > 16:
finished_tasks.pop(0) finished_tasks.pop(0)
def add_task_to_queue(id_job): def add_task_to_queue(user, id_job):
pending_tasks[id_job] = time.time() pending_tasks[(user, id_job)] = time.time()
last_task_id = None
last_task_result = None
last_task_user = None
def set_last_task_result(user, id_job, result):
global last_task_id
global last_task_result
global last_task_user
last_task_id = id_job
last_task_result = result
last_task_user = user
def restore_progress_call(request: gr.Request):
if current_task is None:
# image, generation_info, html_info, html_log
return tuple(list([None, None, None, None]))
else:
user = request.username
if current_task_user == user:
t_task = current_task
with call_queue.queue_lock_condition:
call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id)
return last_task_result
return tuple(list([None, None, None, None]))
class CurrentTaskResponse(BaseModel):
current_task: str = Field(default=None, title="Task ID", description="id of the current progress task")
class ProgressRequest(BaseModel): class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
@ -56,6 +102,21 @@ class ProgressResponse(BaseModel):
def setup_progress_api(app): def setup_progress_api(app):
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
def setup_current_task_api(app):
def get_current_user(token: Optional[str] = Security(APIKeyCookie(name="access-token", auto_error=False))):
return None if token is None else app.tokens.get(token)
def current_task_api(current_user: str = Depends(get_current_user)):
if app.auth is None or current_task_user == current_user:
current_user_task = current_task
else:
current_user_task = None
return CurrentTaskResponse(current_task=current_user_task)
return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse)
def progressapi(req: ProgressRequest): def progressapi(req: ProgressRequest):
active = req.id_task == current_task active = req.id_task == current_task
@ -95,5 +156,4 @@ def progressapi(req: ProgressRequest):
else: else:
live_preview = None live_preview = None
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)

View File

@ -41,6 +41,7 @@ from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text from modules.generation_parameters_copypaste import image_from_url_text
import modules.extras import modules.extras
from modules.progress import restore_progress_call
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
@ -81,6 +82,7 @@ apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️ clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴 extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅ switch_values_symbol = '\U000021C5' # ⇅
restore_progress_symbol = '\U0001F300' # 🌀
def plaintext_to_html(text): def plaintext_to_html(text):
@ -325,6 +327,7 @@ def create_toprow(is_img2img):
extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply") prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create") save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress")
token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
@ -342,7 +345,7 @@ def create_toprow(is_img2img):
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
def setup_progressbar(*args, **kwargs): def setup_progressbar(*args, **kwargs):
@ -459,7 +462,7 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
@ -591,6 +594,18 @@ def create_ui():
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
restore_progress_button.click(
fn=restore_progress_call,
_js="() => restoreProgress('txt2img')",
inputs=[],
outputs=[
txt2img_gallery,
generation_info,
html_info,
html_log,
]
)
txt_prompt_img.change( txt_prompt_img.change(
fn=modules.images.image_data, fn=modules.images.image_data,
inputs=[ inputs=[
@ -659,7 +674,7 @@ def create_ui():
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
@ -951,6 +966,18 @@ def create_ui():
submit.click(**img2img_args) submit.click(**img2img_args)
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False) res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
restore_progress_button.click(
fn=restore_progress_call,
_js="() => restoreProgress('img2img')",
inputs=[],
outputs=[
img2img_gallery,
generation_info,
html_info,
html_log,
]
)
img2img_interrogate.click( img2img_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args), fn=lambda *args: process_interrogate(interrogate, *args),
**interrogate_args, **interrogate_args,
@ -1547,7 +1574,7 @@ def create_ui():
gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages") gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
def unload_sd_weights(): def unload_sd_weights():
modules.sd_models.unload_model_weights() modules.sd_models.unload_model_weights()

View File

@ -339,6 +339,7 @@ def webui():
setup_middleware(app) setup_middleware(app)
modules.progress.setup_progress_api(app) modules.progress.setup_progress_api(app)
modules.progress.setup_current_task_api(app)
if launch_api: if launch_api:
create_api(app) create_api(app)