From 875990a23213c63c19b8fdd3c87345f7a8ea2ceb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 16 May 2023 20:58:35 +0300 Subject: [PATCH] Add option for /_stop route (for graceful shutdown) --- modules/cmd_args.py | 1 + webui.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index f4a4ab36..6144db5c 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -103,3 +103,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') +parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') diff --git a/webui.py b/webui.py index 39dec3ca..5172f049 100644 --- a/webui.py +++ b/webui.py @@ -8,7 +8,7 @@ import warnings import json from threading import Thread -from fastapi import FastAPI +from fastapi import FastAPI, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from packaging import version @@ -270,6 +270,12 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.") api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) + +def stop_route(request): + shared.state.server_command = "stop" + return Response("Stopping.") + + def webui(): launch_api = cmd_opts.api initialize() @@ -318,6 +324,8 @@ def webui(): inbrowser=cmd_opts.autolaunch, prevent_thread_lock=True ) + if cmd_opts.add_stop_route: + app.add_route("/_stop", stop_route, methods=["POST"]) # after initial launch, disable --autolaunch for subsequent restarts cmd_opts.autolaunch = False @@ -359,11 +367,12 @@ def webui(): else: print(f"Unknown server command: {server_command}") except KeyboardInterrupt: + print('Caught KeyboardInterrupt, stopping...') server_command = "stop" if server_command == "stop": + print("Stopping server...") # If we catch a keyboard interrupt, we want to stop the server and exit. - print('Caught KeyboardInterrupt, stopping...') shared.demo.close() break print('Restarting UI...')