Replace state.need_restart with state.server_command + replace poll loop with signal

This commit is contained in:
Aarni Koskela 2023-05-11 23:46:45 +03:00
parent 4b07f2f584
commit 85b4f89926
4 changed files with 68 additions and 26 deletions

View File

@ -2,6 +2,7 @@ import datetime
import json import json
import os import os
import sys import sys
import threading
import time import time
import gradio as gr import gradio as gr
@ -110,8 +111,47 @@ class State:
id_live_preview = 0 id_live_preview = 0
textinfo = None textinfo = None
time_start = None time_start = None
need_restart = False
server_start = None server_start = None
_server_command_signal = threading.Event()
_server_command: str | None = None
@property
def need_restart(self) -> bool:
# Compatibility getter for need_restart.
return self.server_command == "restart"
@need_restart.setter
def need_restart(self, value: bool) -> None:
# Compatibility setter for need_restart.
if value:
self.server_command = "restart"
@property
def server_command(self):
return self._server_command
@server_command.setter
def server_command(self, value: str | None) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()
def wait_for_server_command(self, timeout: float | None = None) -> str | None:
"""
Wait for server command to get set; return and clear the value and signal.
"""
if self._server_command_signal.wait(timeout):
self._server_command_signal.clear()
req = self._server_command
self._server_command = None
return req
return None
def request_restart(self) -> None:
self.interrupt()
self.server_command = True
def skip(self): def skip(self):
self.skipped = True self.skipped = True

View File

@ -1609,12 +1609,8 @@ def create_ui():
outputs=[] outputs=[]
) )
def request_restart():
shared.state.interrupt()
shared.state.need_restart = True
restart_gradio.click( restart_gradio.click(
fn=request_restart, fn=shared.state.request_restart,
_js='restart_reload', _js='restart_reload',
inputs=[], inputs=[],
outputs=[], outputs=[],

View File

@ -52,9 +52,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.opts.disabled_extensions = disabled shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename) shared.opts.save(shared.config_filename)
shared.state.request_restart()
shared.state.interrupt()
shared.state.need_restart = True
def save_config_state(name): def save_config_state(name):
@ -92,8 +90,7 @@ def restore_config_state(confirmed, config_state_name, restore_type):
if restore_type == "webui" or restore_type == "both": if restore_type == "webui" or restore_type == "both":
config_states.restore_webui_config(config_state) config_states.restore_webui_config(config_state)
shared.state.interrupt() shared.state.request_restart()
shared.state.need_restart = True
return "" return ""

View File

@ -234,6 +234,9 @@ def initialize():
print(f'Interrupted with signal {sig} in {frame}') print(f'Interrupted with signal {sig} in {frame}')
os._exit(0) os._exit(0)
if not os.environ.get("COVERAGE_RUN"):
# Don't install the immediate-quit handler when running under coverage,
# as then the coverage report won't be generated.
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
@ -255,19 +258,6 @@ def create_api(app):
return api return api
def wait_on_server(demo=None):
while 1:
time.sleep(0.5)
if shared.state.need_restart:
shared.state.need_restart = False
time.sleep(0.5)
demo.close()
time.sleep(0.5)
modules.script_callbacks.app_reload_callback()
break
def api_only(): def api_only():
initialize() initialize()
@ -328,6 +318,7 @@ def webui():
inbrowser=cmd_opts.autolaunch, inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True prevent_thread_lock=True
) )
# after initial launch, disable --autolaunch for subsequent restarts # after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False cmd_opts.autolaunch = False
@ -359,8 +350,26 @@ def webui():
redirector.get("/") redirector.get("/")
gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}") gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
wait_on_server(shared.demo) try:
while True:
server_command = shared.state.wait_for_server_command(timeout=5)
if server_command:
if server_command in ("stop", "restart"):
break
else:
print(f"Unknown server command: {server_command}")
except KeyboardInterrupt:
server_command = "stop"
if server_command == "stop":
# If we catch a keyboard interrupt, we want to stop the server and exit.
print('Caught KeyboardInterrupt, stopping...')
shared.demo.close()
break
print('Restarting UI...') print('Restarting UI...')
shared.demo.close()
time.sleep(0.5)
modules.script_callbacks.app_reload_callback()
startup_timer.reset() startup_timer.reset()