Save/restore working webui/extension configs

This commit is contained in:
space-nuko 2023-03-29 16:46:03 -05:00
parent 22bcc7be42
commit ad5afcaae0
6 changed files with 243 additions and 9 deletions

View File

@ -47,3 +47,19 @@ function install_extension_from_index(button, url){
gradioApp().querySelector('#install_extension_button').click() gradioApp().querySelector('#install_extension_button').click()
} }
function config_state_confirm_restore(_, config_state_name, config_restore_type) {
if (config_state_name == "Current") {
return [false, config_state_name];
}
let restored = "";
if (config_restore_type == "extensions") {
restored = "all saved extension versions";
} else if (config_restore_type == "webui") {
restored = "the webui version";
} else {
restored = "the webui version and all saved extension versions";
}
let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".\n(A backup of the current state will be made.)");
return [confirmed, config_state_name, config_restore_type];
}

View File

@ -3,10 +3,11 @@ import sys
import traceback import traceback
import time import time
from datetime import datetime
import git import git
from modules import shared from modules import shared
from modules.paths_internal import extensions_dir, extensions_builtin_dir from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
extensions = [] extensions = []
@ -31,12 +32,15 @@ class Extension:
self.status = '' self.status = ''
self.can_update = False self.can_update = False
self.is_builtin = is_builtin self.is_builtin = is_builtin
self.commit_hash = ''
self.commit_date = None
self.version = '' self.version = ''
self.branch = None
self.remote = None self.remote = None
self.have_info_from_repo = False self.have_info_from_repo = False
def read_info_from_repo(self): def read_info_from_repo(self):
if self.have_info_from_repo: if self.is_builtin or self.have_info_from_repo:
return return
self.have_info_from_repo = True self.have_info_from_repo = True
@ -56,10 +60,15 @@ class Extension:
self.status = 'unknown' self.status = 'unknown'
self.remote = next(repo.remote().urls, None) self.remote = next(repo.remote().urls, None)
head = repo.head.commit head = repo.head.commit
ts = time.asctime(time.gmtime(repo.head.commit.committed_date)) self.commit_date = repo.head.commit.committed_date
self.version = f'{head.hexsha[:8]} ({ts})' ts = time.asctime(time.gmtime(self.commit_date))
if repo.active_branch:
self.branch = repo.active_branch.name
self.commit_hash = head.hexsha
self.version = f'{self.commit_hash[:8]} ({ts})'
except Exception: except Exception as ex:
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
self.remote = None self.remote = None
def list_files(self, subdir, extension): def list_files(self, subdir, extension):
@ -88,12 +97,12 @@ class Extension:
self.can_update = False self.can_update = False
self.status = "latest" self.status = "latest"
def fetch_and_reset_hard(self): def fetch_and_reset_hard(self, commit='origin'):
repo = git.Repo(self.path) repo = git.Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`, # Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error. # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True) repo.git.fetch(all=True)
repo.git.reset('origin', hard=True) repo.git.reset(commit, hard=True)
def list_extensions(): def list_extensions():

View File

@ -20,3 +20,4 @@ data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models") models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions") extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states")

View File

@ -424,6 +424,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
options_templates.update(options_section((None, "Hidden options"), { options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable these extensions"), "disabled_extensions": OptionInfo([], "Disable these extensions"),
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}), "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
})) }))

View File

@ -2,6 +2,7 @@ import json
import os.path import os.path
import sys import sys
import time import time
from datetime import datetime
import traceback import traceback
import git import git
@ -11,7 +12,8 @@ import html
import shutil import shutil
import errno import errno
from modules import extensions, shared, paths from modules import extensions, shared, paths, config_states
from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call from modules.call_queue import wrap_gradio_gpu_call
available_extensions = {"extensions": []} available_extensions = {"extensions": []}
@ -30,6 +32,9 @@ def apply_and_restart(disable_list, update_list, disable_all):
update = json.loads(update_list) update = json.loads(update_list)
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}" assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
if update:
save_config_state("Backup (pre-update)")
update = set(update) update = set(update)
for ext in extensions.extensions: for ext in extensions.extensions:
@ -50,6 +55,48 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.state.need_restart = True shared.state.need_restart = True
def save_config_state(name):
current_config_state = config_states.get_config()
if not name:
name = "Config"
current_config_state["name"] = name
filename = os.path.join(config_states_dir, datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + ".json")
print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f:
json.dump(current_config_state, f)
config_states.list_config_states()
new_value = next(iter(config_states.all_config_states.keys()), "Current")
new_choices = ["Current"] + list(config_states.all_config_states.keys())
return gr.Dropdown.update(value=new_value, choices=new_choices), f"<span>Saved current webui/extension state to '{filename}'</span>"
def restore_config_state(confirmed, config_state_name, restore_type):
if config_state_name == "Current":
return "<span>Select a config to restore from.</span>"
if not confirmed:
return "<span>Cancelled.</span>"
check_access()
save_config_state("Backup (pre-restore)")
config_state = config_states.all_config_states[config_state_name]
print(f"Restoring webui state from backup: {restore_type}")
if restore_type == "extensions" or restore_type == "both":
shared.opts.restore_config_state_file = config_state["filename"]
shared.opts.save(shared.config_filename)
if restore_type == "webui" or restore_type == "both":
config_states.restore_webui_config(config_state)
shared.state.interrupt()
shared.state.need_restart = True
return ""
def check_updates(id_task, disable_list): def check_updates(id_task, disable_list):
check_access() check_access()
@ -121,6 +168,117 @@ def extension_table():
return code return code
def update_config_states_table(state_name):
if state_name == "Current":
config_state = config_states.get_config()
else:
config_state = config_states.all_config_states[state_name]
config_name = config_state.get("name", "Config")
created_date = time.asctime(time.gmtime(config_state["created_at"]))
code = f"""<!-- {time.time()} -->"""
webui_remote = config_state["webui"]["remote"] or ""
webui_branch = config_state["webui"]["branch"]
webui_commit_hash = config_state["webui"]["commit_hash"]
if webui_commit_hash:
webui_commit_hash = webui_commit_hash[:8]
else:
webui_commit_hash = "<unknown>"
webui_commit_date = config_state["webui"]["commit_date"]
if webui_commit_date:
webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
else:
webui_commit_date = "<unknown>"
code += f"""<h2>Config Backup: {config_name}</h2>
<span>Created at: {created_date}</span>"""
code += f"""<h2>WebUI State</h2>
<table id="config_state_webui">
<thead>
<tr>
<th>URL</th>
<th>Branch</th>
<th>Commit</th>
<th>Date</th>
</tr>
</thead>
<tbody>
<tr>
<td>{webui_remote}</td>
<td>{webui_branch}</td>
<td>{webui_commit_hash}</td>
<td>{webui_commit_date}</td>
</tr>
</tbody>
</table>
"""
code += """<h2>Extension State</h2>
<table id="config_state_extensions">
<thead>
<tr>
<th>Extension</th>
<th>URL</th>
<th>Branch</th>
<th>Commit</th>
<th>Date</th>
</tr>
</thead>
<tbody>
"""
ext_map = {ext.name: ext for ext in extensions.extensions}
for ext_name, ext_conf in config_state["extensions"].items():
ext_remote = ext_conf["remote"] or ""
ext_branch = ext_conf["branch"] or "<unknown>"
ext_enabled = ext_conf["enabled"]
ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
ext_commit_date = ext_conf["commit_date"]
if ext_commit_date:
ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
else:
ext_commit_date = "<unknown>"
remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
style_enabled = ""
style_remote = ""
style_branch = ""
style_commit = ""
if ext_name in ext_map:
current_ext = ext_map[ext_name]
current_ext.read_info_from_repo()
if current_ext.enabled != ext_enabled:
style_enabled = ' style="color: var(--primary-400)"'
if current_ext.remote != ext_remote:
style_remote = ' style="color: var(--primary-400)"'
if current_ext.branch != ext_branch:
style_branch = ' style="color: var(--primary-400)"'
if current_ext.commit_hash != ext_commit_hash:
style_commit = ' style="color: var(--primary-400)"'
code += f"""
<tr>
<td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
<td><label{style_remote}>{remote}</label></td>
<td><label{style_branch}>{ext_branch}</label></td>
<td><label{style_commit}>{ext_commit_hash[:8]}</label></td>
<td><label{style_commit}>{ext_commit_date}</label></td>
</tr>
"""
code += """
</tbody>
</table>
"""
return code
def normalize_git_url(url): def normalize_git_url(url):
if url is None: if url is None:
return "" return ""
@ -292,6 +450,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
def create_ui(): def create_ui():
import modules.ui import modules.ui
config_states.list_config_states()
with gr.Blocks(analytics_enabled=False) as ui: with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions") as tabs: with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"): with gr.TabItem("Installed"):
@ -386,4 +546,28 @@ def create_ui():
outputs=[extensions_table, install_result], outputs=[extensions_table, install_result],
) )
with gr.TabItem("Backup/Restore"):
with gr.Row(elem_id="extensions_backup_top_row"):
config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys()))
modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states")
config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type")
config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore")
with gr.Row(elem_id="extensions_backup_top_row2"):
config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False)
config_save_button = gr.Button(value="Save Current Config")
config_states_info = gr.HTML("")
config_states_table = gr.HTML(lambda: update_config_states_table("Current"))
config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
dummy_component = gr.Label(visible=False)
config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])
config_states_list.change(
fn=update_config_states_table,
inputs=[config_states_list],
outputs=[config_states_table],
)
return ui return ui

View File

@ -5,6 +5,7 @@ import importlib
import signal import signal
import re import re
import warnings import warnings
import json
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
@ -37,7 +38,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__ torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.face_restoration import modules.face_restoration
import modules.gfpgan_model as gfpgan import modules.gfpgan_model as gfpgan
@ -105,6 +106,17 @@ def initialize():
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list extensions") startup_timer.record("list extensions")
config_state_file = shared.opts.restore_config_state_file
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
if os.path.isfile(config_state_file):
print(f"*** About to restore extension state from file: {config_state_file}")
with open(config_state_file, "r", encoding="utf-8") as f:
config_state = json.load(f)
config_states.restore_extension_state(config_state)
startup_timer.record("restore extension config")
if cmd_opts.ui_debug_mode: if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts() modules.scripts.load_scripts()
@ -301,6 +313,17 @@ def webui():
extensions.list_extensions() extensions.list_extensions()
startup_timer.record("list extensions") startup_timer.record("list extensions")
config_state_file = shared.opts.restore_config_state_file
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
if os.path.isfile(config_state_file):
print(f"*** About to restore extension state from file: {config_state_file}")
with open(config_state_file, "r", encoding="utf-8") as f:
config_state = json.load(f)
config_states.restore_extension_state(config_state)
startup_timer.record("restore extension config")
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
modelloader.forbid_loaded_nonbuiltin_upscalers() modelloader.forbid_loaded_nonbuiltin_upscalers()