Merge pull request #10458 from akx/graceful-stop

Graceful server stopping
This commit is contained in:
AUTOMATIC1111 2023-05-17 18:45:40 +03:00 committed by GitHub
commit f6c06e3ed2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 79 additions and 27 deletions

View File

@ -103,3 +103,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')

View File

@ -2,6 +2,7 @@ import datetime
import json import json
import os import os
import sys import sys
import threading
import time import time
import gradio as gr import gradio as gr
@ -110,8 +111,47 @@ class State:
id_live_preview = 0 id_live_preview = 0
textinfo = None textinfo = None
time_start = None time_start = None
need_restart = False
server_start = None server_start = None
_server_command_signal = threading.Event()
_server_command: str | None = None
@property
def need_restart(self) -> bool:
# Compatibility getter for need_restart.
return self.server_command == "restart"
@need_restart.setter
def need_restart(self, value: bool) -> None:
# Compatibility setter for need_restart.
if value:
self.server_command = "restart"
@property
def server_command(self):
return self._server_command
@server_command.setter
def server_command(self, value: str | None) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()
def wait_for_server_command(self, timeout: float | None = None) -> str | None:
"""
Wait for server command to get set; return and clear the value and signal.
"""
if self._server_command_signal.wait(timeout):
self._server_command_signal.clear()
req = self._server_command
self._server_command = None
return req
return None
def request_restart(self) -> None:
self.interrupt()
self.server_command = True
def skip(self): def skip(self):
self.skipped = True self.skipped = True

View File

@ -1609,12 +1609,8 @@ def create_ui():
outputs=[] outputs=[]
) )
def request_restart():
shared.state.interrupt()
shared.state.need_restart = True
restart_gradio.click( restart_gradio.click(
fn=request_restart, fn=shared.state.request_restart,
_js='restart_reload', _js='restart_reload',
inputs=[], inputs=[],
outputs=[], outputs=[],

View File

@ -52,9 +52,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.opts.disabled_extensions = disabled shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename) shared.opts.save(shared.config_filename)
shared.state.request_restart()
shared.state.interrupt()
shared.state.need_restart = True
def save_config_state(name): def save_config_state(name):
@ -92,8 +90,7 @@ def restore_config_state(confirmed, config_state_name, restore_type):
if restore_type == "webui" or restore_type == "both": if restore_type == "webui" or restore_type == "both":
config_states.restore_webui_config(config_state) config_states.restore_webui_config(config_state)
shared.state.interrupt() shared.state.request_restart()
shared.state.need_restart = True
return "" return ""

View File

@ -8,7 +8,7 @@ import warnings
import json import json
from threading import Thread from threading import Thread
from fastapi import FastAPI from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from packaging import version from packaging import version
@ -241,7 +241,10 @@ def initialize():
print(f'Interrupted with signal {sig} in {frame}') print(f'Interrupted with signal {sig} in {frame}')
os._exit(0) os._exit(0)
signal.signal(signal.SIGINT, sigint_handler) if not os.environ.get("COVERAGE_RUN"):
# Don't install the immediate-quit handler when running under coverage,
# as then the coverage report won't be generated.
signal.signal(signal.SIGINT, sigint_handler)
def setup_middleware(app): def setup_middleware(app):
@ -262,19 +265,6 @@ def create_api(app):
return api return api
def wait_on_server(demo=None):
while 1:
time.sleep(0.5)
if shared.state.need_restart:
shared.state.need_restart = False
time.sleep(0.5)
demo.close()
time.sleep(0.5)
modules.script_callbacks.app_reload_callback()
break
def api_only(): def api_only():
initialize() initialize()
@ -287,6 +277,12 @@ def api_only():
print(f"Startup time: {startup_timer.summary()}.") print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
def stop_route(request):
shared.state.server_command = "stop"
return Response("Stopping.")
def webui(): def webui():
launch_api = cmd_opts.api launch_api = cmd_opts.api
initialize() initialize()
@ -335,6 +331,9 @@ def webui():
inbrowser=cmd_opts.autolaunch, inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True prevent_thread_lock=True
) )
if cmd_opts.add_stop_route:
app.add_route("/_stop", stop_route, methods=["POST"])
# after initial launch, disable --autolaunch for subsequent restarts # after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False cmd_opts.autolaunch = False
@ -366,8 +365,27 @@ def webui():
redirector.get("/") redirector.get("/")
gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}") gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
wait_on_server(shared.demo) try:
while True:
server_command = shared.state.wait_for_server_command(timeout=5)
if server_command:
if server_command in ("stop", "restart"):
break
else:
print(f"Unknown server command: {server_command}")
except KeyboardInterrupt:
print('Caught KeyboardInterrupt, stopping...')
server_command = "stop"
if server_command == "stop":
print("Stopping server...")
# If we catch a keyboard interrupt, we want to stop the server and exit.
shared.demo.close()
break
print('Restarting UI...') print('Restarting UI...')
shared.demo.close()
time.sleep(0.5)
modules.script_callbacks.app_reload_callback()
startup_timer.reset() startup_timer.reset()