Merge pull request #10458 from akx/graceful-stop
Graceful server stopping
This commit is contained in:
commit
f6c06e3ed2
@ -103,3 +103,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
|
|||||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||||
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||||
|
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
||||||
|
@ -2,6 +2,7 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -110,8 +111,47 @@ class State:
|
|||||||
id_live_preview = 0
|
id_live_preview = 0
|
||||||
textinfo = None
|
textinfo = None
|
||||||
time_start = None
|
time_start = None
|
||||||
need_restart = False
|
|
||||||
server_start = None
|
server_start = None
|
||||||
|
_server_command_signal = threading.Event()
|
||||||
|
_server_command: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def need_restart(self) -> bool:
|
||||||
|
# Compatibility getter for need_restart.
|
||||||
|
return self.server_command == "restart"
|
||||||
|
|
||||||
|
@need_restart.setter
|
||||||
|
def need_restart(self, value: bool) -> None:
|
||||||
|
# Compatibility setter for need_restart.
|
||||||
|
if value:
|
||||||
|
self.server_command = "restart"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def server_command(self):
|
||||||
|
return self._server_command
|
||||||
|
|
||||||
|
@server_command.setter
|
||||||
|
def server_command(self, value: str | None) -> None:
|
||||||
|
"""
|
||||||
|
Set the server command to `value` and signal that it's been set.
|
||||||
|
"""
|
||||||
|
self._server_command = value
|
||||||
|
self._server_command_signal.set()
|
||||||
|
|
||||||
|
def wait_for_server_command(self, timeout: float | None = None) -> str | None:
|
||||||
|
"""
|
||||||
|
Wait for server command to get set; return and clear the value and signal.
|
||||||
|
"""
|
||||||
|
if self._server_command_signal.wait(timeout):
|
||||||
|
self._server_command_signal.clear()
|
||||||
|
req = self._server_command
|
||||||
|
self._server_command = None
|
||||||
|
return req
|
||||||
|
return None
|
||||||
|
|
||||||
|
def request_restart(self) -> None:
|
||||||
|
self.interrupt()
|
||||||
|
self.server_command = True
|
||||||
|
|
||||||
def skip(self):
|
def skip(self):
|
||||||
self.skipped = True
|
self.skipped = True
|
||||||
|
@ -1609,12 +1609,8 @@ def create_ui():
|
|||||||
outputs=[]
|
outputs=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
def request_restart():
|
|
||||||
shared.state.interrupt()
|
|
||||||
shared.state.need_restart = True
|
|
||||||
|
|
||||||
restart_gradio.click(
|
restart_gradio.click(
|
||||||
fn=request_restart,
|
fn=shared.state.request_restart,
|
||||||
_js='restart_reload',
|
_js='restart_reload',
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
|
@ -52,9 +52,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
|
|||||||
shared.opts.disabled_extensions = disabled
|
shared.opts.disabled_extensions = disabled
|
||||||
shared.opts.disable_all_extensions = disable_all
|
shared.opts.disable_all_extensions = disable_all
|
||||||
shared.opts.save(shared.config_filename)
|
shared.opts.save(shared.config_filename)
|
||||||
|
shared.state.request_restart()
|
||||||
shared.state.interrupt()
|
|
||||||
shared.state.need_restart = True
|
|
||||||
|
|
||||||
|
|
||||||
def save_config_state(name):
|
def save_config_state(name):
|
||||||
@ -92,8 +90,7 @@ def restore_config_state(confirmed, config_state_name, restore_type):
|
|||||||
if restore_type == "webui" or restore_type == "both":
|
if restore_type == "webui" or restore_type == "both":
|
||||||
config_states.restore_webui_config(config_state)
|
config_states.restore_webui_config(config_state)
|
||||||
|
|
||||||
shared.state.interrupt()
|
shared.state.request_restart()
|
||||||
shared.state.need_restart = True
|
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
50
webui.py
50
webui.py
@ -8,7 +8,7 @@ import warnings
|
|||||||
import json
|
import json
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Response
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@ -241,7 +241,10 @@ def initialize():
|
|||||||
print(f'Interrupted with signal {sig} in {frame}')
|
print(f'Interrupted with signal {sig} in {frame}')
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, sigint_handler)
|
if not os.environ.get("COVERAGE_RUN"):
|
||||||
|
# Don't install the immediate-quit handler when running under coverage,
|
||||||
|
# as then the coverage report won't be generated.
|
||||||
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
|
||||||
|
|
||||||
def setup_middleware(app):
|
def setup_middleware(app):
|
||||||
@ -262,19 +265,6 @@ def create_api(app):
|
|||||||
return api
|
return api
|
||||||
|
|
||||||
|
|
||||||
def wait_on_server(demo=None):
|
|
||||||
while 1:
|
|
||||||
time.sleep(0.5)
|
|
||||||
if shared.state.need_restart:
|
|
||||||
shared.state.need_restart = False
|
|
||||||
time.sleep(0.5)
|
|
||||||
demo.close()
|
|
||||||
time.sleep(0.5)
|
|
||||||
|
|
||||||
modules.script_callbacks.app_reload_callback()
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def api_only():
|
def api_only():
|
||||||
initialize()
|
initialize()
|
||||||
|
|
||||||
@ -287,6 +277,12 @@ def api_only():
|
|||||||
print(f"Startup time: {startup_timer.summary()}.")
|
print(f"Startup time: {startup_timer.summary()}.")
|
||||||
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
||||||
|
|
||||||
|
|
||||||
|
def stop_route(request):
|
||||||
|
shared.state.server_command = "stop"
|
||||||
|
return Response("Stopping.")
|
||||||
|
|
||||||
|
|
||||||
def webui():
|
def webui():
|
||||||
launch_api = cmd_opts.api
|
launch_api = cmd_opts.api
|
||||||
initialize()
|
initialize()
|
||||||
@ -335,6 +331,9 @@ def webui():
|
|||||||
inbrowser=cmd_opts.autolaunch,
|
inbrowser=cmd_opts.autolaunch,
|
||||||
prevent_thread_lock=True
|
prevent_thread_lock=True
|
||||||
)
|
)
|
||||||
|
if cmd_opts.add_stop_route:
|
||||||
|
app.add_route("/_stop", stop_route, methods=["POST"])
|
||||||
|
|
||||||
# after initial launch, disable --autolaunch for subsequent restarts
|
# after initial launch, disable --autolaunch for subsequent restarts
|
||||||
cmd_opts.autolaunch = False
|
cmd_opts.autolaunch = False
|
||||||
|
|
||||||
@ -366,8 +365,27 @@ def webui():
|
|||||||
redirector.get("/")
|
redirector.get("/")
|
||||||
gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
|
gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
|
||||||
|
|
||||||
wait_on_server(shared.demo)
|
try:
|
||||||
|
while True:
|
||||||
|
server_command = shared.state.wait_for_server_command(timeout=5)
|
||||||
|
if server_command:
|
||||||
|
if server_command in ("stop", "restart"):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"Unknown server command: {server_command}")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print('Caught KeyboardInterrupt, stopping...')
|
||||||
|
server_command = "stop"
|
||||||
|
|
||||||
|
if server_command == "stop":
|
||||||
|
print("Stopping server...")
|
||||||
|
# If we catch a keyboard interrupt, we want to stop the server and exit.
|
||||||
|
shared.demo.close()
|
||||||
|
break
|
||||||
print('Restarting UI...')
|
print('Restarting UI...')
|
||||||
|
shared.demo.close()
|
||||||
|
time.sleep(0.5)
|
||||||
|
modules.script_callbacks.app_reload_callback()
|
||||||
|
|
||||||
startup_timer.reset()
|
startup_timer.reset()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user