diff --git a/.gitignore b/.gitignore
index 3b48ba9a..7328401f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,3 +33,4 @@ notification.mp3
/test/stdout.txt
/test/stderr.txt
/cache.json*
+/config_states/
diff --git a/javascript/extensions.js b/javascript/extensions.js
index 72924a28..3c2f995a 100644
--- a/javascript/extensions.js
+++ b/javascript/extensions.js
@@ -47,3 +47,25 @@ function install_extension_from_index(button, url){
gradioApp().querySelector('#install_extension_button').click()
}
+
+function config_state_confirm_restore(_, config_state_name, config_restore_type) {
+ if (config_state_name == "Current") {
+ return [false, config_state_name, config_restore_type];
+ }
+ let restored = "";
+ if (config_restore_type == "extensions") {
+ restored = "all saved extension versions";
+ } else if (config_restore_type == "webui") {
+ restored = "the webui version";
+ } else {
+ restored = "the webui version and all saved extension versions";
+ }
+ let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
+ if (confirmed) {
+ restart_reload();
+ gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
+ x.innerHTML = "Loading..."
+ })
+ }
+ return [confirmed, config_state_name, config_restore_type];
+}
diff --git a/modules/config_states.py b/modules/config_states.py
new file mode 100644
index 00000000..2ea00929
--- /dev/null
+++ b/modules/config_states.py
@@ -0,0 +1,200 @@
+"""
+Supports saving and restoring webui and extensions from a known working set of commits
+"""
+
+import os
+import sys
+import traceback
+import json
+import time
+import tqdm
+
+from datetime import datetime
+from collections import OrderedDict
+import git
+
+from modules import shared, extensions
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
+
+
+all_config_states = OrderedDict()
+
+
+def list_config_states():
+ global all_config_states
+
+ all_config_states.clear()
+ os.makedirs(config_states_dir, exist_ok=True)
+
+ config_states = []
+ for filename in os.listdir(config_states_dir):
+ if filename.endswith(".json"):
+ path = os.path.join(config_states_dir, filename)
+ with open(path, "r", encoding="utf-8") as f:
+ j = json.load(f)
+ j["filepath"] = path
+ config_states.append(j)
+
+ config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
+
+ for cs in config_states:
+ timestamp = time.asctime(time.gmtime(cs["created_at"]))
+ name = cs.get("name", "Config")
+ full_name = f"{name}: {timestamp}"
+ all_config_states[full_name] = cs
+
+ return all_config_states
+
+
+def get_webui_config():
+ webui_repo = None
+
+ try:
+ if os.path.exists(os.path.join(script_path, ".git")):
+ webui_repo = git.Repo(script_path)
+ except Exception:
+ print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ webui_remote = None
+ webui_commit_hash = None
+ webui_commit_date = None
+ webui_branch = None
+ if webui_repo and not webui_repo.bare:
+ try:
+ webui_remote = next(webui_repo.remote().urls, None)
+ head = webui_repo.head.commit
+ webui_commit_date = webui_repo.head.commit.committed_date
+ webui_commit_hash = head.hexsha
+ webui_branch = webui_repo.active_branch.name
+
+ except Exception:
+ webui_remote = None
+
+ return {
+ "remote": webui_remote,
+ "commit_hash": webui_commit_hash,
+ "commit_date": webui_commit_date,
+ "branch": webui_branch,
+ }
+
+
+def get_extension_config():
+ ext_config = {}
+
+ for ext in extensions.extensions:
+ entry = {
+ "name": ext.name,
+ "path": ext.path,
+ "enabled": ext.enabled,
+ "is_builtin": ext.is_builtin,
+ "remote": ext.remote,
+ "commit_hash": ext.commit_hash,
+ "commit_date": ext.commit_date,
+ "branch": ext.branch,
+ "have_info_from_repo": ext.have_info_from_repo
+ }
+
+ ext_config[ext.name] = entry
+
+ return ext_config
+
+
+def get_config():
+ creation_time = datetime.now().timestamp()
+ webui_config = get_webui_config()
+ ext_config = get_extension_config()
+
+ return {
+ "created_at": creation_time,
+ "webui": webui_config,
+ "extensions": ext_config
+ }
+
+
+def restore_webui_config(config):
+ print("* Restoring webui state...")
+
+ if "webui" not in config:
+ print("Error: No webui data saved to config")
+ return
+
+ webui_config = config["webui"]
+
+ if "commit_hash" not in webui_config:
+ print("Error: No commit saved to webui config")
+ return
+
+ webui_commit_hash = webui_config.get("commit_hash", None)
+ webui_repo = None
+
+ try:
+ if os.path.exists(os.path.join(script_path, ".git")):
+ webui_repo = git.Repo(script_path)
+ except Exception:
+ print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return
+
+ try:
+ webui_repo.git.fetch(all=True)
+ webui_repo.git.reset(webui_commit_hash, hard=True)
+ print(f"* Restored webui to commit {webui_commit_hash}.")
+ except Exception:
+ print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+
+def restore_extension_config(config):
+ print("* Restoring extension state...")
+
+ if "extensions" not in config:
+ print("Error: No extension data saved to config")
+ return
+
+ ext_config = config["extensions"]
+
+ results = []
+ disabled = []
+
+ for ext in tqdm.tqdm(extensions.extensions):
+ if ext.is_builtin:
+ continue
+
+ ext.read_info_from_repo()
+ current_commit = ext.commit_hash
+
+ if ext.name not in ext_config:
+ ext.disabled = True
+ disabled.append(ext.name)
+ results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
+ continue
+
+ entry = ext_config[ext.name]
+
+ if "commit_hash" in entry and entry["commit_hash"]:
+ try:
+ ext.fetch_and_reset_hard(entry["commit_hash"])
+ ext.read_info_from_repo()
+ if current_commit != entry["commit_hash"]:
+ results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
+ except Exception as ex:
+ results.append((ext, current_commit[:8], False, ex))
+ else:
+ results.append((ext, current_commit[:8], False, "No commit hash found in config"))
+
+ if not entry.get("enabled", False):
+ ext.disabled = True
+ disabled.append(ext.name)
+ else:
+ ext.disabled = False
+
+ shared.opts.disabled_extensions = disabled
+ shared.opts.save(shared.config_filename)
+
+ print("* Finished restoring extensions. Results:")
+ for ext, prev_commit, success, result in results:
+ if success:
+ print(f" + {ext.name}: {prev_commit} -> {result}")
+ else:
+ print(f" ! {ext.name}: FAILURE ({result})")
diff --git a/modules/extensions.py b/modules/extensions.py
index 3a7a0372..34d9d654 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -3,10 +3,11 @@ import sys
import traceback
import time
+from datetime import datetime
import git
from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
extensions = []
@@ -31,12 +32,15 @@ class Extension:
self.status = ''
self.can_update = False
self.is_builtin = is_builtin
+ self.commit_hash = ''
+ self.commit_date = None
self.version = ''
+ self.branch = None
self.remote = None
self.have_info_from_repo = False
def read_info_from_repo(self):
- if self.have_info_from_repo:
+ if self.is_builtin or self.have_info_from_repo:
return
self.have_info_from_repo = True
@@ -56,10 +60,15 @@ class Extension:
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
head = repo.head.commit
- ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
- self.version = f'{head.hexsha[:8]} ({ts})'
+ self.commit_date = repo.head.commit.committed_date
+ ts = time.asctime(time.gmtime(self.commit_date))
+ if repo.active_branch:
+ self.branch = repo.active_branch.name
+ self.commit_hash = head.hexsha
+ self.version = f'{self.commit_hash[:8]} ({ts})'
- except Exception:
+ except Exception as ex:
+ print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
self.remote = None
def list_files(self, subdir, extension):
@@ -82,18 +91,30 @@ class Extension:
for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
- self.status = "behind"
+ self.status = "new commits"
return
+ try:
+ origin = repo.rev_parse('origin')
+ if repo.head.commit != origin:
+ self.can_update = True
+ self.status = "behind HEAD"
+ return
+ except Exception:
+ self.can_update = False
+ self.status = "unknown (remote error)"
+ return
+
self.can_update = False
self.status = "latest"
- def fetch_and_reset_hard(self):
+ def fetch_and_reset_hard(self, commit='origin'):
repo = git.Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True)
- repo.git.reset('origin', hard=True)
+ repo.git.reset(commit, hard=True)
+ self.have_info_from_repo = False
def list_extensions():
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 926ec3bb..6765bafe 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -20,3 +20,4 @@ data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
+config_states_dir = os.path.join(script_path, "config_states")
diff --git a/modules/shared.py b/modules/shared.py
index 6a14dcd0..6a2b3c2b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -449,6 +449,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable these extensions"),
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
+ "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index e90bedc8..79ff2389 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -2,6 +2,7 @@ import json
import os.path
import sys
import time
+from datetime import datetime
import traceback
import git
@@ -11,10 +12,12 @@ import html
import shutil
import errno
-from modules import extensions, shared, paths
+from modules import extensions, shared, paths, config_states
+from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call
available_extensions = {"extensions": []}
+STYLE_PRIMARY = ' style="color: var(--primary-400)"'
def check_access():
@@ -30,6 +33,9 @@ def apply_and_restart(disable_list, update_list, disable_all):
update = json.loads(update_list)
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
+ if update:
+ save_config_state("Backup (pre-update)")
+
update = set(update)
for ext in extensions.extensions:
@@ -50,6 +56,46 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.state.need_restart = True
+def save_config_state(name):
+ current_config_state = config_states.get_config()
+ if not name:
+ name = "Config"
+ current_config_state["name"] = name
+ filename = os.path.join(config_states_dir, datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + "_" + name + ".json")
+ print(f"Saving backup of webui/extension state to {filename}.")
+ with open(filename, "w", encoding="utf-8") as f:
+ json.dump(current_config_state, f)
+ config_states.list_config_states()
+ new_value = next(iter(config_states.all_config_states.keys()), "Current")
+ new_choices = ["Current"] + list(config_states.all_config_states.keys())
+ return gr.Dropdown.update(value=new_value, choices=new_choices), f"Saved current webui/extension state to \"{filename}\""
+
+
+def restore_config_state(confirmed, config_state_name, restore_type):
+ if config_state_name == "Current":
+ return "Select a config to restore from."
+ if not confirmed:
+ return "Cancelled."
+
+ check_access()
+
+ config_state = config_states.all_config_states[config_state_name]
+
+ print(f"*** Restoring webui state from backup: {restore_type} ***")
+
+ if restore_type == "extensions" or restore_type == "both":
+ shared.opts.restore_config_state_file = config_state["filepath"]
+ shared.opts.save(shared.config_filename)
+
+ if restore_type == "webui" or restore_type == "both":
+ config_states.restore_webui_config(config_state)
+
+ shared.state.interrupt()
+ shared.state.need_restart = True
+
+ return ""
+
+
def check_updates(id_task, disable_list):
check_access()
@@ -76,6 +122,16 @@ def check_updates(id_task, disable_list):
return extension_table(), ""
+def make_commit_link(commit_hash, remote, text=None):
+ if text is None:
+ text = commit_hash[:8]
+ if remote.startswith("https://github.com/"):
+ href = os.path.join(remote, "commit", commit_hash)
+ return f'{text}'
+ else:
+ return text
+
+
def extension_table():
code = f"""
@@ -102,13 +158,17 @@ def extension_table():
style = ""
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
- style = ' style="color: var(--primary-400)"'
+ style = STYLE_PRIMARY
+
+ version_link = ext.version
+ if ext.commit_hash and ext.remote:
+ version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version)
code += f"""
|
{remote} |
- {ext.version} |
+ {version_link} |
{ext_status} |
"""
@@ -121,6 +181,133 @@ def extension_table():
return code
+def update_config_states_table(state_name):
+ if state_name == "Current":
+ config_state = config_states.get_config()
+ else:
+ config_state = config_states.all_config_states[state_name]
+
+ config_name = config_state.get("name", "Config")
+ created_date = time.asctime(time.gmtime(config_state["created_at"]))
+ filepath = config_state.get("filepath", "")
+
+ code = f""""""
+
+ webui_remote = config_state["webui"]["remote"] or ""
+ webui_branch = config_state["webui"]["branch"]
+ webui_commit_hash = config_state["webui"]["commit_hash"] or ""
+ webui_commit_date = config_state["webui"]["commit_date"]
+ if webui_commit_date:
+ webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
+ else:
+ webui_commit_date = ""
+
+ remote = f"""{html.escape(webui_remote or '')}"""
+ commit_link = make_commit_link(webui_commit_hash, webui_remote)
+ date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
+
+ current_webui = config_states.get_webui_config()
+
+ style_remote = ""
+ style_branch = ""
+ style_commit = ""
+ if current_webui["remote"] != webui_remote:
+ style_remote = STYLE_PRIMARY
+ if current_webui["branch"] != webui_branch:
+ style_branch = STYLE_PRIMARY
+ if current_webui["commit_hash"] != webui_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code += f"""Config Backup: {config_name}
+ Filepath: {filepath}
+ Created at: {created_date}
"""
+
+ code += f"""WebUI State
+
+
+
+ URL |
+ Branch |
+ Commit |
+ Date |
+
+
+
+
+ |
+ |
+ |
+ |
+
+
+
+ """
+
+ code += """Extension State
+
+ """
+
+ return code
+
+
def normalize_git_url(url):
if url is None:
return ""
@@ -299,6 +486,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
def create_ui():
import modules.ui
+ config_states.list_config_states()
+
with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"):
@@ -394,4 +583,28 @@ def create_ui():
outputs=[extensions_table, install_result],
)
+ with gr.TabItem("Backup/Restore"):
+ with gr.Row(elem_id="extensions_backup_top_row"):
+ config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys()))
+ modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states")
+ config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type")
+ config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore")
+ with gr.Row(elem_id="extensions_backup_top_row2"):
+ config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False)
+ config_save_button = gr.Button(value="Save Current Config")
+
+ config_states_info = gr.HTML("")
+ config_states_table = gr.HTML(lambda: update_config_states_table("Current"))
+
+ config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
+
+ dummy_component = gr.Label(visible=False)
+ config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])
+
+ config_states_list.change(
+ fn=update_config_states_table,
+ inputs=[config_states_list],
+ outputs=[config_states_table],
+ )
+
return ui
diff --git a/webui.py b/webui.py
index 95623c6f..ae3285c6 100644
--- a/webui.py
+++ b/webui.py
@@ -5,6 +5,7 @@ import importlib
import signal
import re
import warnings
+import json
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@@ -40,7 +41,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
+from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
@@ -150,6 +151,19 @@ def initialize():
localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list extensions")
+ config_state_file = shared.opts.restore_config_state_file
+ shared.opts.restore_config_state_file = ""
+ shared.opts.save(shared.config_filename)
+
+ if os.path.isfile(config_state_file):
+ print(f"*** About to restore extension state from file: {config_state_file}")
+ with open(config_state_file, "r", encoding="utf-8") as f:
+ config_state = json.load(f)
+ config_states.restore_extension_config(config_state)
+ startup_timer.record("restore extension config")
+ elif config_state_file:
+ print(f"!!! Config state backup not found: {config_state_file}")
+
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
@@ -344,6 +358,19 @@ def webui():
extensions.list_extensions()
startup_timer.record("list extensions")
+ config_state_file = shared.opts.restore_config_state_file
+ shared.opts.restore_config_state_file = ""
+ shared.opts.save(shared.config_filename)
+
+ if os.path.isfile(config_state_file):
+ print(f"*** About to restore extension state from file: {config_state_file}")
+ with open(config_state_file, "r", encoding="utf-8") as f:
+ config_state = json.load(f)
+ config_states.restore_extension_config(config_state)
+ startup_timer.record("restore extension config")
+ elif config_state_file:
+ print(f"!!! Config state backup not found: {config_state_file}")
+
localization.list_localizations(cmd_opts.localizations_dir)
modelloader.forbid_loaded_nonbuiltin_upscalers()