From bbce167305091b34795284cabca7ab2fd56469b6 Mon Sep 17 00:00:00 2001 From: lenankamp <31517075+lenankamp@users.noreply.github.com> Date: Tue, 16 May 2023 14:37:45 -0400 Subject: [PATCH 001/178] Recursive batch img2img.py Searches sub directories and performs img2img batch processing, also limits inputs to jpg, webp, and png. Then saves to putput directory with relative paths. --- modules/img2img.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 9fc3a698..ad5f2e73 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -20,7 +20,13 @@ import modules.scripts def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processing.fix_seed(p) - images = shared.listfiles(input_dir) +# recursive batch, as written limits potential inputs to common image formats, may e better to just check if isfile for general use +images = [] + for root, directories, files in os.walk(input_dir): + for filename in files: + filepath = os.path.join(root, filename) + if filepath.endswith(".jpg") or filepath.endswith(".jpeg") or filepath.endswith(".png") or filepath.endswith(".webp"): + images.append(filepath) is_inpaint_batch = False if inpaint_mask_dir: @@ -70,16 +76,17 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): for n, processed_image in enumerate(proc.images): filename = os.path.basename(image) + relpath = os.path.dirname(os.path.relpath(image, input_dir)) if n > 0: left, right = os.path.splitext(filename) filename = f"{left}-{n}{right}" if not save_normally: - os.makedirs(output_dir, exist_ok=True) + os.makedirs(os.path.join(output_dir, relpath), exist_ok=True) if processed_image.mode == 'RGBA': processed_image = processed_image.convert("RGB") - processed_image.save(os.path.join(output_dir, filename)) + processed_image.save(os.path.join(output_dir, relpath, filename)) def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): From ff6acd35d0807a4e0c3ee86cdb1520a4a3a11cdd Mon Sep 17 00:00:00 2001 From: lenankamp <31517075+lenankamp@users.noreply.github.com> Date: Fri, 19 May 2023 03:20:19 -0400 Subject: [PATCH 002/178] Update img2img.py Hopefully corrected the white space issue --- modules/img2img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index ad5f2e73..d1872bed 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -20,14 +20,14 @@ import modules.scripts def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processing.fix_seed(p) -# recursive batch, as written limits potential inputs to common image formats, may e better to just check if isfile for general use -images = [] + images = [] for root, directories, files in os.walk(input_dir): for filename in files: filepath = os.path.join(root, filename) if filepath.endswith(".jpg") or filepath.endswith(".jpeg") or filepath.endswith(".png") or filepath.endswith(".webp"): images.append(filepath) + is_inpaint_batch = False if inpaint_mask_dir: inpaint_masks = shared.listfiles(inpaint_mask_dir) From 468056958b63cb869d627746b5b5a1c629fd7548 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Mon, 22 May 2023 20:46:25 -0600 Subject: [PATCH 003/178] Add reorder hotkeys Shifts selected items with ctrl+left/right --- javascript/edit-order.js | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 javascript/edit-order.js diff --git a/javascript/edit-order.js b/javascript/edit-order.js new file mode 100644 index 00000000..e924d419 --- /dev/null +++ b/javascript/edit-order.js @@ -0,0 +1,36 @@ +function keyupEditOrder(event){ + let target = event.originalTarget || event.composedPath()[0]; + if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return; + if (!event.metaKey && !event.ctrlKey) return; + + let isLeft = event.key == "ArrowLeft" + let isRight = event.key == "ArrowRight" + if (!isLeft && !isRight) return; + + let selectionStart = target.selectionStart; + let selectionEnd = target.selectionEnd; + let text = target.value; + let items = text.split(",") + let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length + let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length + let range = indexEnd - indexStart + 1 + + if (isLeft && indexStart > 0) { + items.splice(indexStart - 1, 0, ...items.splice(indexStart, range)) + target.value = items.join() + target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1) + target.selectionEnd = items.slice(0, indexEnd).join().length + } else if (isRight && indexEnd < items.length - 1) { + items.splice(indexStart + 1, 0, ...items.splice(indexStart, range)) + target.value = items.join() + target.selectionStart = items.slice(0, indexStart + 1).join().length + 1 + target.selectionEnd = items.slice(0, indexEnd + 2).join().length + } + + event.preventDefault() + updateInput(target) +} + +addEventListener('keydown', (event) => { + keyupEditOrder(event); +}); From dafe5193633e4ab27aad74de63ea5fcc3d31aba8 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Mon, 22 May 2023 21:23:39 -0600 Subject: [PATCH 004/178] Fix lint errors --- javascript/edit-order.js | 50 ++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/javascript/edit-order.js b/javascript/edit-order.js index e924d419..11703350 100644 --- a/javascript/edit-order.js +++ b/javascript/edit-order.js @@ -1,34 +1,34 @@ -function keyupEditOrder(event){ - let target = event.originalTarget || event.composedPath()[0]; - if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return; - if (!event.metaKey && !event.ctrlKey) return; +function keyupEditOrder(event) { + let target = event.originalTarget || event.composedPath()[0]; + if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return; + if (!event.metaKey && !event.ctrlKey) return; - let isLeft = event.key == "ArrowLeft" - let isRight = event.key == "ArrowRight" - if (!isLeft && !isRight) return; + let isLeft = event.key == "ArrowLeft"; + let isRight = event.key == "ArrowRight"; + if (!isLeft && !isRight) return; - let selectionStart = target.selectionStart; - let selectionEnd = target.selectionEnd; - let text = target.value; - let items = text.split(",") - let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length - let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length - let range = indexEnd - indexStart + 1 + let selectionStart = target.selectionStart; + let selectionEnd = target.selectionEnd; + let text = target.value; + let items = text.split(","); + let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length; + let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length; + let range = indexEnd - indexStart + 1; if (isLeft && indexStart > 0) { - items.splice(indexStart - 1, 0, ...items.splice(indexStart, range)) - target.value = items.join() - target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1) - target.selectionEnd = items.slice(0, indexEnd).join().length + items.splice(indexStart - 1, 0, ...items.splice(indexStart, range)); + target.value = items.join(); + target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1); + target.selectionEnd = items.slice(0, indexEnd).join().length; } else if (isRight && indexEnd < items.length - 1) { - items.splice(indexStart + 1, 0, ...items.splice(indexStart, range)) - target.value = items.join() - target.selectionStart = items.slice(0, indexStart + 1).join().length + 1 - target.selectionEnd = items.slice(0, indexEnd + 2).join().length + items.splice(indexStart + 1, 0, ...items.splice(indexStart, range)); + target.value = items.join(); + target.selectionStart = items.slice(0, indexStart + 1).join().length + 1; + target.selectionEnd = items.slice(0, indexEnd + 2).join().length; } - - event.preventDefault() - updateInput(target) + + event.preventDefault(); + updateInput(target); } addEventListener('keydown', (event) => { From 43bdaa2f0eda79c685792b06a2bd84c65806a48f Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Thu, 25 May 2023 18:49:28 -0600 Subject: [PATCH 005/178] Make ctrl+left/right optional --- javascript/edit-order.js | 1 + modules/shared.py | 1 + 2 files changed, 2 insertions(+) diff --git a/javascript/edit-order.js b/javascript/edit-order.js index 11703350..50f7fe37 100644 --- a/javascript/edit-order.js +++ b/javascript/edit-order.js @@ -1,4 +1,5 @@ function keyupEditOrder(event) { + if (!opts.keyedit_move) return; let target = event.originalTarget || event.composedPath()[0]; if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return; if (!event.metaKey && !event.ctrlKey) return; diff --git a/modules/shared.py b/modules/shared.py index b3508883..280c4451 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -405,6 +405,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), + "keyedit_move": OptionInfo(False, "Ctrl+left/right moves prompt elements"), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}), "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), From ba70a220e3176153ba2a559acb9e5aa692dce7ca Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 5 Jun 2023 22:20:29 +0300 Subject: [PATCH 006/178] Remove a bunch of unused/vestigial code As found by Vulture and some eyes --- modules/api/api.py | 7 ------ modules/api/models.py | 4 --- modules/codeformer_model.py | 4 --- modules/devices.py | 7 ------ modules/generation_parameters_copypaste.py | 29 ---------------------- modules/hypernetworks/hypernetwork.py | 24 ------------------ modules/paths.py | 14 ----------- 7 files changed, 89 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 2e49526e..41cd7eca 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -32,13 +32,6 @@ import piexif import piexif.helper -def upscaler_to_index(name: str): - try: - return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) - except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e - - def script_name_to_index(name, scripts): try: return [script.title().lower() for script in scripts].index(name.lower()) diff --git a/modules/api/models.py b/modules/api/models.py index b3a745f0..b5683071 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel): prompt: Optional[str] = Field(title="Prompt") negative_prompt: Optional[str] = Field(title="Negative Prompt") -class ArtistItem(BaseModel): - name: str = Field(title="Name") - score: float = Field(title="Score") - category: str = Field(title="Category") class EmbeddingItem(BaseModel): step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available") diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 4260b016..a01fe63d 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -15,7 +15,6 @@ model_dir = "Codeformer" model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' -have_codeformer = False codeformer = None @@ -125,9 +124,6 @@ def setup_model(dirname): return restored_img - global have_codeformer - have_codeformer = True - global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) diff --git a/modules/devices.py b/modules/devices.py index 1ed6ffdc..620ed1a6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -15,13 +15,6 @@ def has_mps() -> bool: else: return mac_specific.has_mps -def extract_device_id(args, name): - for x in range(len(args)): - if name in args[x]: - return args[x + 1] - - return None - def get_cuda_device_string(): from modules import shared diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 1d02ffae..699b1a81 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -174,31 +174,6 @@ def send_image_and_dimensions(x): return img, w, h - -def find_hypernetwork_key(hypernet_name, hypernet_hash=None): - """Determines the config parameter name to use for the hypernet based on the parameters in the infotext. - - Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config - parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to. - - If the infotext has no hash, then a hypernet with the same name will be selected instead. - """ - hypernet_name = hypernet_name.lower() - if hypernet_hash is not None: - # Try to match the hash in the name - for hypernet_key in shared.hypernetworks.keys(): - result = re_hypernet_hash.search(hypernet_key) - if result is not None and result[1] == hypernet_hash: - return hypernet_key - else: - # Fall back to a hypernet with the same name - for hypernet_key in shared.hypernetworks.keys(): - if hypernet_key.lower().startswith(hypernet_name): - return hypernet_key - - return None - - def restore_old_hires_fix_params(res): """for infotexts that specify old First pass size parameter, convert it into width, height, and hr scale""" @@ -329,10 +304,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model return res -settings_map = {} - - - infotext_to_setting_name_mapping = [ ('Clip skip', 'CLIP_stop_at_last_layers', ), ('Conditional mask weight', 'inpainting_mask_weight'), diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 5d12b449..51941c11 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -353,17 +353,6 @@ def load_hypernetworks(names, multipliers=None): shared.loaded_hypernetworks.append(hypernetwork) -def find_closest_hypernetwork_name(search: str): - if not search: - return None - search = search.lower() - applicable = [name for name in shared.hypernetworks if search in name.lower()] - if not applicable: - return None - applicable = sorted(applicable, key=lambda name: len(name)) - return applicable[0] - - def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) @@ -446,18 +435,6 @@ def statistics(data): return total_information, recent_information -def report_statistics(loss_info:dict): - keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) - for key in keys: - try: - print("Loss statistics for file " + key) - info, recent = statistics(list(loss_info[key])) - print(info) - print(recent) - except Exception as e: - print(e) - - def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) @@ -770,7 +747,6 @@ Last saved image: {html.escape(last_saved_image)}
pbar.leave = False pbar.close() hypernetwork.eval() - #report_statistics(loss_dict) sd_hijack_checkpoint.remove() diff --git a/modules/paths.py b/modules/paths.py index 5171df4f..bada804e 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs: else: sys.path.append(d) paths[what] = d - - -class Prioritize: - def __init__(self, name): - self.name = name - self.path = None - - def __enter__(self): - self.path = sys.path.copy() - sys.path = [paths[self.name]] + sys.path - - def __exit__(self, exc_type, exc_val, exc_tb): - sys.path = self.path - self.path = None From 8ca34ad6d8cc2502403b3b96bb811366bc13c076 Mon Sep 17 00:00:00 2001 From: Su Wei Date: Fri, 9 Jun 2023 13:14:20 +0800 Subject: [PATCH 007/178] add model exists status check to modeuls/api/api.py , /sdapi/v1/options [POST] --- modules/api/api.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index eee99bbb..56b7858d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights +from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights,checkpoint_alisases from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -515,6 +515,11 @@ class Api: def set_config(self, req: Dict[str, Any]): for k, v in req.items(): + if k == "sd_model_checkpoint": + checkpoint_info = checkpoint_alisases.get(v, None) + if checkpoint_info is None: + print(f"model [{v}] not founded, skip config saving process") + return shared.opts.set(k, v) shared.opts.save(shared.config_filename) From 9142be0a0d8cea37cf1ae86c17fc7dcb161d9a42 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 10 Jun 2023 23:36:34 +0900 Subject: [PATCH 008/178] quit restart --- modules/api/api.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 2e49526e..317c809f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -208,6 +208,8 @@ class Api: self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) + self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"]) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -715,3 +717,10 @@ class Api: def launch(self, server_name, port): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) + + def quit_webui(self): + restart.stop_program() + + def restart_webui(self): + if restart.is_restartable(): + restart.restart_program() From 7e2d39a2d158d1426321686b05d31a3ea694a99e Mon Sep 17 00:00:00 2001 From: Su Wei Date: Mon, 12 Jun 2023 15:22:49 +0800 Subject: [PATCH 009/178] update model checkpoint switch code --- modules/api/api.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 56b7858d..7d7dfe9a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -514,12 +514,11 @@ class Api: return options def set_config(self, req: Dict[str, Any]): + checkpoint_key="sd_model_checkpoint" + if checkpoint_key in req and str(req[checkpoint_key]) not in checkpoint_alisases: + raise RuntimeError(f"model {v!r} not found") + for k, v in req.items(): - if k == "sd_model_checkpoint": - checkpoint_info = checkpoint_alisases.get(v, None) - if checkpoint_info is None: - print(f"model [{v}] not founded, skip config saving process") - return shared.opts.set(k, v) shared.opts.save(shared.config_filename) From b9664ab6154818680ee25920e229b808a3cdec68 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 12 Jun 2023 18:15:27 +0900 Subject: [PATCH 010/178] move _stop route to api --- modules/api/api.py | 11 +++++++++-- webui.py | 7 ------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 317c809f..cb1cde78 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -208,8 +208,11 @@ class Api: self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) - self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"]) - self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"]) + + if shared.cmd_opts.add_stop_route: + self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"]) + self.add_api_route("/_stop", self.stop_route, methods=["POST"]) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -724,3 +727,7 @@ class Api: def restart_webui(self): if restart.is_restartable(): restart.restart_program() + + def stop_route(request): + shared.state.server_command = "stop" + return Response("Stopping.") diff --git a/webui.py b/webui.py index 136d036d..ae6dc9fb 100644 --- a/webui.py +++ b/webui.py @@ -362,11 +362,6 @@ def api_only(): api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) -def stop_route(request): - shared.state.server_command = "stop" - return Response("Stopping.") - - def webui(): launch_api = cmd_opts.api initialize() @@ -404,8 +399,6 @@ def webui(): "redoc_url": "/redoc", }, ) - if cmd_opts.add_stop_route: - app.add_route("/_stop", stop_route, methods=["POST"]) # after initial launch, disable --autolaunch for subsequent restarts cmd_opts.autolaunch = False From d80962681ae0f3456b1c2776f68c5c838d782786 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 12 Jun 2023 18:21:01 +0900 Subject: [PATCH 011/178] remove fastapi.Response --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index ae6dc9fb..bad29f28 100644 --- a/webui.py +++ b/webui.py @@ -11,7 +11,7 @@ import json from threading import Thread from typing import Iterable -from fastapi import FastAPI, Response +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from packaging import version From 89352a2f52c6be51318192cedd86c8a342966a49 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:34:26 +0300 Subject: [PATCH 012/178] Move `load_file_from_url` to modelloader --- extensions-builtin/LDSR/scripts/ldsr_model.py | 7 ++--- .../ScuNET/scripts/scunet_model.py | 5 ++-- .../SwinIR/scripts/swinir_model.py | 8 +++-- modules/esrgan_model.py | 4 +-- modules/modelloader.py | 29 +++++++++++++++++-- modules/realesrgan_model.py | 4 +-- 6 files changed, 39 insertions(+), 18 deletions(-) diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index dbd6d331..bf9b6de2 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -1,7 +1,6 @@ import os -from basicsr.utils.download_util import load_file_from_url - +from modules.modelloader import load_file_from_url from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR from modules import shared, script_callbacks, errors @@ -43,9 +42,9 @@ class UpscalerLDSR(Upscaler): if local_safetensors_path is not None and os.path.exists(local_safetensors_path): model = local_safetensors_path else: - model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True) + model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt") - yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True) + yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml") try: return LDSR(model, yaml) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 85b4505f..2785b551 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -6,12 +6,11 @@ import numpy as np import torch from tqdm import tqdm -from basicsr.utils.download_util import load_file_from_url - import modules.upscaler from modules import devices, modelloader, script_callbacks, errors from scunet_model_arch import SCUNet as net +from modules.modelloader import load_file_from_url from modules.shared import opts @@ -120,7 +119,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def load_model(self, path: str): device = devices.get_device_for('scunet') if "http" in path: - filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True) + filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 1c7bf325..a5b0e2eb 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -3,7 +3,6 @@ import os import numpy as np import torch from PIL import Image -from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared @@ -50,8 +49,11 @@ class UpscalerSwinIR(Upscaler): def load_model(self, path, scale=4): if "http" in path: - dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") - filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True) + filename = modelloader.load_file_from_url( + url=path, + model_dir=self.model_download_path, + file_name=f"{self.model_name.replace(' ', '_')}.pth", + ) else: filename = path if filename is None or not os.path.exists(filename): diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 2fced999..f1a98c07 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -3,7 +3,6 @@ import os import numpy as np import torch from PIL import Image -from basicsr.utils.download_util import load_file_from_url import modules.esrgan_model_arch as arch from modules import modelloader, images, devices @@ -152,11 +151,10 @@ class UpscalerESRGAN(Upscaler): def load_model(self, path: str): if "http" in path: - filename = load_file_from_url( + filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, file_name=f"{self.model_name}.pth", - progress=True, ) else: filename = path diff --git a/modules/modelloader.py b/modules/modelloader.py index be23071a..a69c8a4f 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import importlib @@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale from modules.paths import script_path, models_path +def load_file_from_url( + url: str, + *, + model_dir: str, + progress: bool = True, + file_name: str | None = None, +) -> str: + """Download a file from `url` into `model_dir`, using the file present if possible. + + Returns the path to the downloaded file. + """ + os.makedirs(model_dir, exist_ok=True) + if not file_name: + parts = urlparse(url) + file_name = os.path.basename(parts.path) + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + from torch.hub import download_url_to_file + download_url_to_file(url, cached_file, progress=progress) + return cached_file + + def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. @@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if model_url is not None and len(output) == 0: if download_name is not None: - from basicsr.utils.download_util import load_file_from_url - dl = load_file_from_url(model_url, places[0], True, download_name) - output.append(dl) + output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name)) else: output.append(model_url) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 2d27b321..0d9c2e48 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -2,7 +2,6 @@ import os import numpy as np from PIL import Image -from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer from modules.upscaler import Upscaler, UpscalerData @@ -10,6 +9,7 @@ from modules.shared import cmd_opts, opts from modules import modelloader, errors + class UpscalerRealESRGAN(Upscaler): def __init__(self, path): self.name = "RealESRGAN" @@ -71,7 +71,7 @@ class UpscalerRealESRGAN(Upscaler): return None if info.local_data_path.startswith("http"): - info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) + info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path) return info except Exception: From 0afbc0c2355ead3a0ce7149a6d678f1f2e2fbfee Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:41:36 +0300 Subject: [PATCH 013/178] Fix up `if "http" in ...:` to be more sensible startswiths --- extensions-builtin/ScuNET/scripts/scunet_model.py | 4 ++-- extensions-builtin/SwinIR/scripts/swinir_model.py | 4 ++-- modules/esrgan_model.py | 4 ++-- modules/gfpgan_model.py | 2 +- modules/modelloader.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 2785b551..64f50829 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -27,7 +27,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers = [] add_model2 = True for file in model_paths: - if "http" in file: + if file.startswith("http"): name = self.model_name else: name = modelloader.friendly_name(file) @@ -118,7 +118,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def load_model(self, path: str): device = devices.get_device_for('scunet') - if "http" in path: + if path.startswith("http"): filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index a5b0e2eb..4551761d 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -27,7 +27,7 @@ class UpscalerSwinIR(Upscaler): scalers = [] model_files = self.find_models(ext_filter=[".pt", ".pth"]) for model in model_files: - if "http" in model: + if model.startswith("http"): name = self.model_name else: name = modelloader.friendly_name(model) @@ -48,7 +48,7 @@ class UpscalerSwinIR(Upscaler): return img def load_model(self, path, scale=4): - if "http" in path: + if path.startswith("http"): filename = modelloader.load_file_from_url( url=path, model_dir=self.model_download_path, diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index f1a98c07..0666a2c2 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -133,7 +133,7 @@ class UpscalerESRGAN(Upscaler): scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) scalers.append(scaler_data) for file in model_paths: - if "http" in file: + if file.startswith("http"): name = self.model_name else: name = modelloader.friendly_name(file) @@ -150,7 +150,7 @@ class UpscalerESRGAN(Upscaler): return img def load_model(self, path: str): - if "http" in path: + if path.startswith("http"): filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index e239a09d..804fb53d 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -25,7 +25,7 @@ def gfpgann(): return None models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") - if len(models) == 1 and "http" in models[0]: + if len(models) == 1 and models[0].startswith("http"): model_file = models[0] elif len(models) != 0: latest_file = max(models, key=os.path.getctime) diff --git a/modules/modelloader.py b/modules/modelloader.py index a69c8a4f..b2f0bb71 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -82,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None def friendly_name(file: str): - if "http" in file: + if file.startswith("http"): file = urlparse(file).path file = os.path.basename(file) From e3a973a68df3cfe13039dae33d19cf2c02a741e0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:45:07 +0300 Subject: [PATCH 014/178] Add TODO comments to sus model loads --- extensions-builtin/ScuNET/scripts/scunet_model.py | 1 + modules/esrgan_model.py | 1 + 2 files changed, 2 insertions(+) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 64f50829..da74a829 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -119,6 +119,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def load_model(self, path: str): device = devices.get_device_for('scunet') if path.startswith("http"): + # TODO: this doesn't use `path` at all? filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 0666a2c2..a20e8d91 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -151,6 +151,7 @@ class UpscalerESRGAN(Upscaler): def load_model(self, path: str): if path.startswith("http"): + # TODO: this doesn't use `path` at all? filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, From bf67a5dcf44c3dbd88d1913478d4e02477915f33 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 10:38:51 +0300 Subject: [PATCH 015/178] Upscaler.load_model: don't return None, just use exceptions --- extensions-builtin/LDSR/scripts/ldsr_model.py | 13 +++--- .../ScuNET/scripts/scunet_model.py | 16 +++----- .../SwinIR/scripts/swinir_model.py | 40 +++++++++---------- modules/esrgan_model.py | 14 +++---- modules/realesrgan_model.py | 33 +++++++-------- 5 files changed, 52 insertions(+), 64 deletions(-) diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index bf9b6de2..bd78dece 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -46,16 +46,13 @@ class UpscalerLDSR(Upscaler): yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml") - try: - return LDSR(model, yaml) - except Exception: - errors.report("Error importing LDSR", exc_info=True) - return None + return LDSR(model, yaml) def do_upscale(self, img, path): - ldsr = self.load_model(path) - if ldsr is None: - print("NO LDSR!") + try: + ldsr = self.load_model(path) + except Exception: + errors.report(f"Failed loading LDSR model {path}", exc_info=True) return img ddim_steps = shared.opts.ldsr_steps return ldsr.super_resolution(img, ddim_steps, self.scale) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index da74a829..ffef26b2 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,4 +1,3 @@ -import os.path import sys import PIL.Image @@ -8,7 +7,7 @@ from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet as net +from scunet_model_arch import SCUNet from modules.modelloader import load_file_from_url from modules.shared import opts @@ -88,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): torch.cuda.empty_cache() - model = self.load_model(selected_file) - if model is None: - print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr) + try: + model = self.load_model(selected_file) + except Exception as e: + print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) return img device = devices.get_device_for('scunet') @@ -123,11 +123,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: - print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) - return None - - model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) + model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) model.load_state_dict(torch.load(filename), strict=True) model.eval() for _, v in model.named_parameters(): diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 4551761d..3ce622d9 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,4 +1,4 @@ -import os +import sys import numpy as np import torch @@ -7,8 +7,8 @@ from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared from modules.shared import opts, state -from swinir_model_arch import SwinIR as net -from swinir_model_arch_v2 import Swin2SR as net2 +from swinir_model_arch import SwinIR +from swinir_model_arch_v2 import Swin2SR from modules.upscaler import Upscaler, UpscalerData @@ -36,8 +36,10 @@ class UpscalerSwinIR(Upscaler): self.scalers = scalers def do_upscale(self, img, model_file): - model = self.load_model(model_file) - if model is None: + try: + model = self.load_model(model_file) + except Exception as e: + print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) return img model = model.to(device_swinir, dtype=devices.dtype) img = upscale(img, model) @@ -56,25 +58,23 @@ class UpscalerSwinIR(Upscaler): ) else: filename = path - if filename is None or not os.path.exists(filename): - return None if filename.endswith(".v2.pth"): - model = net2( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="1conv", + model = Swin2SR( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="1conv", ) params = None else: - model = net( + model = SwinIR( upscale=scale, in_chans=3, img_size=64, diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a20e8d91..02a1727d 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,4 +1,4 @@ -import os +import sys import numpy as np import torch @@ -6,9 +6,8 @@ from PIL import Image import modules.esrgan_model_arch as arch from modules import modelloader, images, devices -from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts - +from modules.upscaler import Upscaler, UpscalerData def mod2normal(state_dict): @@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler): self.scalers.append(scaler_data) def do_upscale(self, img, selected_model): - model = self.load_model(selected_model) - if model is None: + try: + model = self.load_model(selected_model) + except Exception as e: + print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) return img model.to(devices.device_esrgan) img = esrgan_upscale(model, img) @@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler): ) else: filename = path - if not os.path.exists(filename) or filename is None: - print(f"Unable to load {self.model_path} from {filename}") - return None state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 0d9c2e48..0700b853 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -9,7 +9,6 @@ from modules.shared import cmd_opts, opts from modules import modelloader, errors - class UpscalerRealESRGAN(Upscaler): def __init__(self, path): self.name = "RealESRGAN" @@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler): if not self.enable: return img - info = self.load_model(path) - if not os.path.exists(info.local_data_path): - print(f"Unable to load RealESRGAN model: {info.name}") + try: + info = self.load_model(path) + except Exception: + errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img upsampler = RealESRGANer( @@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler): return image def load_model(self, path): - try: - info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None) - - if info is None: - print(f"Unable to find model info: {path}") - return None - - if info.local_data_path.startswith("http"): - info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path) - - return info - except Exception: - errors.report("Error making Real-ESRGAN models list", exc_info=True) - return None + for scaler in self.scalers: + if scaler.data_path == path: + if scaler.local_data_path.startswith("http"): + scaler.local_data_path = modelloader.load_file_from_url( + scaler.data_path, + model_dir=self.model_download_path, + ) + if not os.path.exists(scaler.local_data_path): + raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}") + return scaler + raise ValueError(f"Unable to find model info: {path}") def load_models(self, _): return get_realesrgan_models(self) From 2667f47ffbf7c641a7e77abbdddf5e81bf144199 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 13 Jun 2023 13:00:05 +0300 Subject: [PATCH 016/178] Remove stray space from SwinIR model URL --- extensions-builtin/SwinIR/scripts/swinir_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 3ce622d9..c6bc53a8 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -11,6 +11,7 @@ from swinir_model_arch import SwinIR from swinir_model_arch_v2 import Swin2SR from modules.upscaler import Upscaler, UpscalerData +SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" device_swinir = devices.get_device_for('swinir') @@ -18,9 +19,7 @@ device_swinir = devices.get_device_for('swinir') class UpscalerSwinIR(Upscaler): def __init__(self, dirname): self.name = "SwinIR" - self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ - "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ - "-L_x4_GAN.pth " + self.model_url = SWINIR_MODEL_URL self.model_name = "SwinIR 4x" self.user_path = dirname super().__init__() From 8ce9b36e0fe51002e72f90ec4dbdc53b564c8fad Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 13 Jun 2023 13:07:06 +0300 Subject: [PATCH 017/178] Upgrade ruff to 272 --- .github/workflows/on_pull_request.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml index 7b7219fd..8ebf5918 100644 --- a/.github/workflows/on_pull_request.yaml +++ b/.github/workflows/on_pull_request.yaml @@ -18,7 +18,7 @@ jobs: # not to have GHA download an (at the time of writing) 4 GB cache # of PyTorch and other dependencies. - name: Install Ruff - run: pip install ruff==0.0.265 + run: pip install ruff==0.0.272 - name: Run Ruff run: ruff . lint-js: From d8071647760a2213aaf33a533addb4d84ba86816 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 13 Jun 2023 13:07:39 +0300 Subject: [PATCH 018/178] textual_inversion/logging.py: clean up duplicate key from sets (and sort them) (Ruff B033) --- modules/textual_inversion/logging.py | 48 +++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py index 734a4b6f..45823eb1 100644 --- a/modules/textual_inversion/logging.py +++ b/modules/textual_inversion/logging.py @@ -2,11 +2,51 @@ import datetime import json import os -saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"} -saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} -saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} +saved_params_shared = { + "batch_size", + "clip_grad_mode", + "clip_grad_value", + "create_image_every", + "data_root", + "gradient_step", + "initial_step", + "latent_sampling_method", + "learn_rate", + "log_directory", + "model_hash", + "model_name", + "num_of_dataset_images", + "steps", + "template_file", + "training_height", + "training_width", +} +saved_params_ti = { + "embedding_name", + "num_vectors_per_token", + "save_embedding_every", + "save_image_with_stored_embedding", +} +saved_params_hypernet = { + "activation_func", + "add_layer_norm", + "hypernetwork_name", + "layer_structure", + "save_hypernetwork_every", + "use_dropout", + "weight_init", +} saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet -saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} +saved_params_previews = { + "preview_cfg_scale", + "preview_height", + "preview_negative_prompt", + "preview_prompt", + "preview_sampler_index", + "preview_seed", + "preview_steps", + "preview_width", +} def save_settings_to_file(log_directory, all_params): From 5be6c026f55760039b3ebb284cc2ce85586be4ac Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 14 Jun 2023 18:51:47 +0900 Subject: [PATCH 019/178] rename routes --- modules/api/api.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index cb1cde78..5ea1d21c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -210,9 +210,9 @@ class Api: self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) if shared.cmd_opts.add_stop_route: - self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"]) - self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"]) - self.add_api_route("/_stop", self.stop_route, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-terminate", self.terminate_webui, methods=["POST"]) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -721,13 +721,13 @@ class Api: self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) - def quit_webui(self): + def kill_webui(self): restart.stop_program() def restart_webui(self): if restart.is_restartable(): restart.restart_program() - def stop_route(request): + def terminate_webui(request): shared.state.server_command = "stop" return Response("Stopping.") From 6387f0e85d207705e0a68178bbf71aa81ba82256 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 14 Jun 2023 18:51:54 +0900 Subject: [PATCH 020/178] update workflow kill test server --- .github/workflows/run_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 226cf759..394811ac 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -50,7 +50,7 @@ jobs: python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test - name: Kill test server if: always() - run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10 + run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-terminate && sleep 10 - name: Show coverage run: | python -m coverage combine .coverage* From 49fb2a337661d1b9a80de8ff35a640083fa98d2f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 14 Jun 2023 19:52:12 +0900 Subject: [PATCH 021/178] response 501 if not a able to restart --- modules/api/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/api/api.py b/modules/api/api.py index 5ea1d21c..4dc48a03 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -727,6 +727,7 @@ class Api: def restart_webui(self): if restart.is_restartable(): restart.restart_program() + return Response(status_code=501) def terminate_webui(request): shared.state.server_command = "stop" From 6091c4e4aa32b674f8ec755e6bd58989f09b08c5 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 14 Jun 2023 19:53:08 +0900 Subject: [PATCH 022/178] terminate -> stop --- .github/workflows/run_tests.yaml | 2 +- modules/api/api.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 394811ac..96546011 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -50,7 +50,7 @@ jobs: python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test - name: Kill test server if: always() - run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-terminate && sleep 10 + run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10 - name: Show coverage run: | python -m coverage combine .coverage* diff --git a/modules/api/api.py b/modules/api/api.py index 4dc48a03..80d45ca7 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -212,7 +212,7 @@ class Api: if shared.cmd_opts.add_stop_route: self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"]) - self.add_api_route("/sdapi/v1/server-terminate", self.terminate_webui, methods=["POST"]) + self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"]) self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] @@ -729,6 +729,6 @@ class Api: restart.restart_program() return Response(status_code=501) - def terminate_webui(request): + def stop_webui(request): shared.state.server_command = "stop" return Response("Stopping.") From fa9d2ac2ff7cf6fbc73525190bd7fde724ec1fb3 Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Wed, 14 Jun 2023 13:53:13 -0500 Subject: [PATCH 023/178] Fix gradio special args in the call queue --- modules/call_queue.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/call_queue.py b/modules/call_queue.py index 447bb764..64ebf868 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,3 +1,4 @@ +from functools import wraps import html import sys import threading @@ -20,6 +21,7 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None): + @wraps(func) def f(*args, **kwargs): # if the first argument is a string that says "task(...)", it is treated as a job id @@ -47,6 +49,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def wrap_gradio_call(func, extra_outputs=None, add_stats=False): + @wraps(func) def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats if run_memmon: From 376f793bded0e7df40eafcacfd086e4e4d218bc5 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 15 Jun 2023 04:23:52 +0900 Subject: [PATCH 024/178] git clone show progress --- modules/launch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 609a181e..97539e68 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -147,10 +147,10 @@ def git_clone(url, dir, name, commithash=None): return run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") - run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") + run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True) return - run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") + run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True) if commithash is not None: run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") From 1d7c51fb9f757b5dcdc506f8fc003e6047151567 Mon Sep 17 00:00:00 2001 From: Beinsezii Date: Wed, 14 Jun 2023 13:07:22 -0700 Subject: [PATCH 025/178] WEBUI.SH Navi 3 Support Navi 3 card now defaults to nightly torch to utilize rocm 5.5 for out-of-the-box support. https://download.pytorch.org/whl/nightly/ While its not yet on the main pytorch "get started" site, it still seems perfectly indexable via pip which is all we need. With this I'm able to clone a fresh repo and immediately run ./webui.sh on my 7900 XTX without any problems. --- webui.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/webui.sh b/webui.sh index 5c8d977c..c889c55e 100755 --- a/webui.sh +++ b/webui.sh @@ -131,6 +131,10 @@ case "$gpu_info" in ;; *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 ;; + *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.5" + # Navi 3 needs at least 5.5 which is only on the nightly chain + ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 printf "\n%s\n" "${delimiter}" printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half" From d3c86e5178725b11a4679097f0aefb0a9fc90014 Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Wed, 14 Jun 2023 14:03:44 -0500 Subject: [PATCH 026/178] Note the Gradio user in the Exif data --- modules/img2img.py | 5 ++++- modules/processing.py | 3 +++ modules/txt2img.py | 6 ++++-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index d704bf90..83bd7857 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -2,6 +2,7 @@ import os import numpy as np from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError +import gradio as gr from modules import sd_samplers from modules.generation_parameters_copypaste import create_override_settings_dict @@ -78,7 +79,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -160,6 +161,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s p.scripts = modules.scripts.scripts_img2img p.script_args = args + p.user = request.username + if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index d22b353f..3e8d153e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -180,6 +180,8 @@ class StableDiffusionProcessing: self.uc = None self.c = None + self.user = None + @property def sd_model(self): return shared.sd_model @@ -578,6 +580,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, + "User": p.user, } generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) diff --git a/modules/txt2img.py b/modules/txt2img.py index 2e7d202d..6aa79f23 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -4,10 +4,10 @@ from modules.generation_parameters_copypaste import create_override_settings_dic from modules.shared import opts, cmd_opts import modules.shared as shared from modules.ui import plaintext_to_html +import gradio as gr - -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) p = processing.StableDiffusionProcessingTxt2Img( @@ -48,6 +48,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step p.scripts = modules.scripts.scripts_txt2img p.script_args = args + p.user = request.username + if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) From 9ec2ba2d28bb0d8f01e19e2919b7bf2e3e864773 Mon Sep 17 00:00:00 2001 From: XiaoMeng Mai Date: Thu, 15 Jun 2023 22:43:09 +0800 Subject: [PATCH 027/178] Add github mirror for the download extension --- modules/ui_extensions.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 4379a641..6c717313 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -322,11 +322,21 @@ def normalize_git_url(url): return url -def install_extension_from_url(dirname, url, branch_name=None): +def install_extension_from_url(dirname, proxy, url, branch_name=None): check_access() assert url, 'No URL specified' + proxy_list = { + "none": "", + "ghproxy": "https://ghproxy.com/", + "hub.yzuu.cf": "https://hub.yzuu.cf/", + "hub.njuu.cf": "https://hub.njuu.cf/", + "hub.nuaa.cf": "https://hub.nuaa.cf/", + } + + url = proxy_list[proxy] + url + if dirname is None or dirname == "": *parts, last_part = url.split('/') last_part = normalize_git_url(last_part) @@ -346,12 +356,12 @@ def install_extension_from_url(dirname, url, branch_name=None): shutil.rmtree(tmpdir, True) if not branch_name: # if no branch is specified, use the default branch - with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], verbose=False) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() else: - with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name, verbose=False) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() @@ -593,6 +603,12 @@ def create_ui(): ) with gr.TabItem("Install from URL", id="install_from_url"): + + install_proxy = gr.Radio( + label="Install Proxy", choices=["none", "ghproxy", "hub.nuaa.cf","hub.yzuu.cf","hub.njuu.cf"], value="none", + info="If you can't access github.com, you can use a proxy to install extensions from github.com" + ) + install_url = gr.Text(label="URL for extension's git repository") install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") @@ -601,7 +617,7 @@ def create_ui(): install_button.click( fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), - inputs=[install_dirname, install_url, install_branch], + inputs=[install_dirname, install_proxy, install_url, install_branch], outputs=[install_url, extensions_table, install_result], ) From de022c4c80240a430a8099fb27a41aa505bf5b2f Mon Sep 17 00:00:00 2001 From: XiaoMeng Mai Date: Thu, 15 Jun 2023 22:59:46 +0800 Subject: [PATCH 028/178] Update code style --- modules/ui_extensions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 6c717313..e4423a06 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -336,7 +336,6 @@ def install_extension_from_url(dirname, proxy, url, branch_name=None): } url = proxy_list[proxy] + url - if dirname is None or dirname == "": *parts, last_part = url.split('/') last_part = normalize_git_url(last_part) @@ -603,9 +602,8 @@ def create_ui(): ) with gr.TabItem("Install from URL", id="install_from_url"): - install_proxy = gr.Radio( - label="Install Proxy", choices=["none", "ghproxy", "hub.nuaa.cf","hub.yzuu.cf","hub.njuu.cf"], value="none", + label="Install Proxy", choices=["none", "ghproxy", "hub.nuaa.cf","hub.yzuu.cf","hub.njuu.cf"], value="none", info="If you can't access github.com, you can use a proxy to install extensions from github.com" ) From 8f18e672439fa1926717df2c938e7089149f3a8b Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Thu, 15 Jun 2023 10:53:16 -0500 Subject: [PATCH 029/178] Add a user pattern to the filename generator --- modules/images.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/images.py b/modules/images.py index 40efc96c..92b924ef 100644 --- a/modules/images.py +++ b/modules/images.py @@ -359,6 +359,7 @@ class FilenameGenerator: 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt..] 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"], 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT, + 'user': lambda self: self.p.user, } default_time_format = '%Y%m%d%H%M%S' From f603275d84301b5ee952683e951dd1aad72ba615 Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Thu, 15 Jun 2023 10:55:53 -0500 Subject: [PATCH 030/178] Add an opt-in infotext user name setting --- modules/processing.py | 2 +- modules/shared.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 3e8d153e..a0cc8db2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -580,7 +580,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, - "User": p.user, + "User": p.user if opts.add_user_name_to_info else None, } generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) diff --git a/modules/shared.py b/modules/shared.py index 271a062d..4c639a21 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -496,6 +496,7 @@ options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('infotext', "Infotext"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), + "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"), "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), })) From e9bd18c57bd83363d38c7409263fe87f3ed3a7f0 Mon Sep 17 00:00:00 2001 From: XiaoMeng Mai Date: Fri, 16 Jun 2023 00:09:54 +0800 Subject: [PATCH 031/178] Update call method --- modules/ui_extensions.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index e4423a06..2b542e59 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -330,12 +330,16 @@ def install_extension_from_url(dirname, proxy, url, branch_name=None): proxy_list = { "none": "", "ghproxy": "https://ghproxy.com/", - "hub.yzuu.cf": "https://hub.yzuu.cf/", - "hub.njuu.cf": "https://hub.njuu.cf/", - "hub.nuaa.cf": "https://hub.nuaa.cf/", + "yzuu": "hub.yzuu.cf", + "njuu": "hub.njuu.cf", + "nuaa": "hub.nuaa.cf", } - url = proxy_list[proxy] + url + if proxy in ['yzuu', 'njuu', 'nuaa']: + url = url.replace('github.com', proxy_list[proxy]) + elif proxy == 'ghproxy': + url = proxy_list[proxy] + url + if dirname is None or dirname == "": *parts, last_part = url.split('/') last_part = normalize_git_url(last_part) @@ -603,7 +607,7 @@ def create_ui(): with gr.TabItem("Install from URL", id="install_from_url"): install_proxy = gr.Radio( - label="Install Proxy", choices=["none", "ghproxy", "hub.nuaa.cf","hub.yzuu.cf","hub.njuu.cf"], value="none", + label="Install Proxy", choices=["none", "ghproxy", "nuaa", "yzuu", "njuu"], value="none", info="If you can't access github.com, you can use a proxy to install extensions from github.com" ) From 41363e0d27bbaa0c84eebe3c7c8451075390ec4e Mon Sep 17 00:00:00 2001 From: dhwz Date: Fri, 16 Jun 2023 18:10:15 +0200 Subject: [PATCH 032/178] fix very slow loading speed of .safetensors files --- modules/sd_models.py | 7 +++++-- modules/shared.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 918f6fd6..d9ac675b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -247,8 +247,11 @@ def read_metadata_from_safetensors(filename): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - device = map_location or shared.weight_load_location or devices.get_optimal_device_name() - pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) + if not shared.opts.disable_mmap_load_safetensors: + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() + pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) + else: + pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read()) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) diff --git a/modules/shared.py b/modules/shared.py index 91c31d55..6b0ccac1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -376,6 +376,7 @@ options_templates.update(options_section(('system', "System"), { "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), + "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files (fixes very slow loading speed in some cases)."), })) options_templates.update(options_section(('training', "Training"), { From 373ff5a217eca33607abb692b9ebfa38abb7fe33 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Fri, 16 Jun 2023 15:17:17 -0400 Subject: [PATCH 033/178] :bug: Allow Script to have metaclass --- modules/scripts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/scripts.py b/modules/scripts.py index c902804b..52682fbf 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -2,6 +2,7 @@ import os import re import sys import traceback +import inspect from collections import namedtuple import gradio as gr @@ -238,7 +239,7 @@ def load_scripts(): def register_scripts_from_module(module): for script_class in module.__dict__.values(): - if type(script_class) != type: + if not inspect.isclass(script_class): continue if issubclass(script_class, Script): From 2e1710d88edc1e1a08b01c063fa386b50e5abc30 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 18 Jun 2023 14:07:41 +0900 Subject: [PATCH 034/178] update the description of --add-stop-rout --- modules/cmd_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index de905caa..624dcb4f 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') +parser.add_argument('--add-stop-route', action='store_true', help='enable server stop/restart/kill via api') From d2ccdcdc97f8e8b0a1a63f2031716b3866c7b53b Mon Sep 17 00:00:00 2001 From: George Gu Date: Mon, 19 Jun 2023 10:16:18 +0800 Subject: [PATCH 035/178] fix: adding elem_id for img2img resize to and resize by tabs --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 361f596e..ce019f9c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -789,7 +789,7 @@ def create_ui(): selected_scale_tab = gr.State(value=0) with gr.Tabs(): - with gr.Tab(label="Resize to") as tab_scale_to: + with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: with FormRow(): with gr.Column(elem_id="img2img_column_size", scale=4): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") @@ -798,7 +798,7 @@ def create_ui(): res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn") - with gr.Tab(label="Resize by") as tab_scale_by: + with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") with FormRow(): From 27e9e3f6fa41c10eb5256662ddf3643dee933810 Mon Sep 17 00:00:00 2001 From: stablegeniusdiffuser Date: Mon, 19 Jun 2023 20:36:44 +0200 Subject: [PATCH 036/178] Add use_main_prompt parameter to use proper metadata for batch run grids or individual images --- modules/processing.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 8da73884..1d97e95e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -549,7 +549,7 @@ def program_version(): return res -def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0): +def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False): index = position_in_batch + iteration * p.batch_size clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) @@ -589,9 +589,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) + prompt_text = p.prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else "" - return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip() def process_images(p: StableDiffusionProcessing) -> Processed: @@ -663,8 +664,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] - def infotext(iteration=0, position_in_batch=0): - return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) + def infotext(iteration=0, position_in_batch=0, use_main_prompt=False): + return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt) if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() @@ -824,7 +825,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: grid = images.image_grid(output_images, p.batch_size) if opts.return_grid: - text = infotext() + text = infotext(use_main_prompt=True) infotexts.insert(0, text) if opts.enable_pnginfo: grid.info["parameters"] = text @@ -832,7 +833,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: index_of_first_image = 1 if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True) if not p.disable_extra_networks and p.extra_network_data: extra_networks.deactivate(p, p.extra_network_data) From 928bd42da46683315c9f4498f6fbd5c59279da18 Mon Sep 17 00:00:00 2001 From: Ferdinand Weynschenk Date: Tue, 20 Jun 2023 13:33:36 +0200 Subject: [PATCH 037/178] PNG info support at img2img batch --- modules/img2img.py | 38 +++++++++++++++++++++++++++++++++----- modules/ui.py | 7 +++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index d704bf90..88e172ff 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -3,8 +3,8 @@ import os import numpy as np from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError -from modules import sd_samplers -from modules.generation_parameters_copypaste import create_override_settings_dict +from modules import sd_samplers, images as imgutil +from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state import modules.shared as shared @@ -13,7 +13,7 @@ from modules.ui import plaintext_to_html import modules.scripts -def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): +def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, output_dir, inpaint_mask_dir, args): processing.fix_seed(p) images = shared.listfiles(input_dir) @@ -34,6 +34,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): state.job_count = len(images) * p.n_iter + prompt = p.prompt + negative_prompt = p.negative_prompt + for i, image in enumerate(images): state.job = f"{i+1} out of {len(images)}" if state.skipped: @@ -59,6 +62,31 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): mask_image_path = inpaint_masks[0] mask_image = Image.open(mask_image_path) p.image_mask = mask_image + + if use_png_info: + try: + info_img = img + if png_info_dir: + info_img_path = os.path.join(png_info_dir, os.path.basename(image)) + info_img = Image.open(info_img_path) + geninfo, _ = imgutil.read_info_from_image(info_img) + parsed_parameters = parse_generation_parameters(geninfo) + if("Prompt" in png_info_props): + p.prompt = prompt + " " + parsed_parameters["Prompt"] + if("Negative prompt" in png_info_props): + p.negative_prompt = negative_prompt + " " + parsed_parameters["Negative prompt"] + if("Seed" in png_info_props): + p.seed = int(parsed_parameters["Seed"]) + if("CFG scale" in png_info_props): + p.cfg_scale = float(parsed_parameters["CFG scale"]) + if("Sampler" in png_info_props): + p.sampler_name = parsed_parameters["Sampler"] + if("Steps" in png_info_props): + p.steps = int(parsed_parameters["Steps"]) + except: + p.prompt = prompt + p.negative_prompt = negative_prompt + print(f"batch png info: using ui set prompts; failed to get png info for {image}") proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: @@ -78,7 +106,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -169,7 +197,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args) + process_batch(p, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args) processed = Processed(p, [], p.seed, "") else: diff --git a/modules/ui.py b/modules/ui.py index 361f596e..a79b0e6c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -751,6 +751,10 @@ def create_ui(): img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + with gr.Accordion("PNG info"): + img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") + img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") + img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] @@ -943,6 +947,9 @@ def create_ui(): inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, + img2img_batch_use_png_info, + img2img_batch_png_info_props, + img2img_batch_png_info_dir, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, From 7ad48120d45e678b6343f7d95a1f97607858009a Mon Sep 17 00:00:00 2001 From: Ferdinand Weynschenk Date: Tue, 20 Jun 2023 13:50:02 +0200 Subject: [PATCH 038/178] use ui params when retreiving png info fails Don't want to interrupt the process since batches can be huge. This makes more sense to me than using the previous images parameters --- modules/img2img.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modules/img2img.py b/modules/img2img.py index 88e172ff..e46a6fde 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -34,8 +34,13 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp state.job_count = len(images) * p.n_iter + # extract "default" params to use in case getting png info fails prompt = p.prompt negative_prompt = p.negative_prompt + seed = p.seed + cfg_scale = p.cfg_scale + sampler_name = p.sampler_name + steps = p.steps for i, image in enumerate(images): state.job = f"{i+1} out of {len(images)}" @@ -86,6 +91,10 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp except: p.prompt = prompt p.negative_prompt = negative_prompt + p.seed = seed + p.cfg_scale = cfg_scale + p.sampler_name = sampler_name + p.steps = steps print(f"batch png info: using ui set prompts; failed to get png info for {image}") proc = modules.scripts.scripts_img2img.run(p, *args) From c4c63dd5e4760c56405cef2e71abc5c3604c4578 Mon Sep 17 00:00:00 2001 From: Ferdinand Weynschenk Date: Tue, 20 Jun 2023 14:03:42 +0200 Subject: [PATCH 039/178] resolve linter --- modules/img2img.py | 7 ++++--- modules/ui.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index e46a6fde..f77dfd9f 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -67,7 +67,7 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp mask_image_path = inpaint_masks[0] mask_image = Image.open(mask_image_path) p.image_mask = mask_image - + if use_png_info: try: info_img = img @@ -88,14 +88,15 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp p.sampler_name = parsed_parameters["Sampler"] if("Steps" in png_info_props): p.steps = int(parsed_parameters["Steps"]) - except: + except Exception as e: + print(f"batch png info: using ui set prompts; failed to get png info for {image}") + print(e) p.prompt = prompt p.negative_prompt = negative_prompt p.seed = seed p.cfg_scale = cfg_scale p.sampler_name = sampler_name p.steps = steps - print(f"batch png info: using ui set prompts; failed to get png info for {image}") proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: diff --git a/modules/ui.py b/modules/ui.py index a79b0e6c..d9b21534 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -948,7 +948,7 @@ def create_ui(): inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_use_png_info, - img2img_batch_png_info_props, + img2img_batch_png_info_props, img2img_batch_png_info_dir, img2img_batch_input_dir, img2img_batch_output_dir, From dd268c48c9099c4cf308eb04590bd201c9b64253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20=28Netux=29=20Rodr=C3=ADguez?= Date: Sun, 25 Jun 2023 00:30:08 -0300 Subject: [PATCH 040/178] feat(extensions): add toggle all checkbox to Installed tab Small QoL addition. While there is the option to disable all extensions with the radio buttons at the top, that only acts as an added flag and doesn't really change the state of the extensions in the UI. An use case for this checkbox is to disable all extensions except for a few, which is important for debugging extensions. You could do that before, but you'd have to uncheck and recheck every extension one by one. --- javascript/extensions.js | 18 ++++++++++++++++++ modules/ui_extensions.py | 7 +++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/javascript/extensions.js b/javascript/extensions.js index efeaf3a5..1f7254c5 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type) } return [confirmed, config_state_name, config_restore_type]; } + +function toggle_all_extensions(event) { + gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) { + checkbox_el.checked = event.target.checked; + }); +} + +function toggle_extension() { + let all_extensions_toggled = true; + for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) { + if (!checkbox_el.checked) { + all_extensions_toggled = false; + break; + } + } + + gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled; +} diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 4379a641..50955fab 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -138,7 +138,10 @@ def extension_table(): - + @@ -170,7 +173,7 @@ def extension_table(): code += f""" - + From 9bb1fcfad43103778406ace17e6804c67fad9c17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 27 Jun 2023 08:59:35 +0300 Subject: [PATCH 041/178] alternate fix for catch errors when retrieving extension index #11290 --- modules/ui_extensions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index f3db76f2..278bf5e4 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -571,9 +571,9 @@ def create_ui(): available_extensions_table = gr.HTML() refresh_available_extensions_button.click( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), + fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]), inputs=[available_extensions_index, hide_tags, sort_column], - outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text], + outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result], ) install_extension_button.click( From 24129368f1b732be25ef486edb2cf5a6ace66737 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 27 Jun 2023 09:19:04 +0300 Subject: [PATCH 042/178] send tensors to the correct device when loading from safetensors file with memmap disabled for #11260 --- modules/sd_models.py | 4 +++- modules/shared.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 0391398a..f65f4e36 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -246,11 +246,13 @@ def read_metadata_from_safetensors(filename): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() + if not shared.opts.disable_mmap_load_safetensors: - device = map_location or shared.weight_load_location or devices.get_optimal_device_name() pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read()) + pl_sd = {k: v.to(device) for k, v in pl_sd.items()} else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) diff --git a/modules/shared.py b/modules/shared.py index 4743a428..203ee1b9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -376,7 +376,7 @@ options_templates.update(options_section(('system', "System"), { "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), - "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files (fixes very slow loading speed in some cases)."), + "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), })) options_templates.update(options_section(('training', "Training"), { From d06af4e517865277d0521642c2c5513af9afd76f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 27 Jun 2023 09:26:18 +0300 Subject: [PATCH 043/178] fix and rework #11113 --- modules/api/api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index f96056b6..1d4aeff5 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -522,9 +522,9 @@ class Api: return options def set_config(self, req: Dict[str, Any]): - checkpoint_key="sd_model_checkpoint" - if checkpoint_key in req and str(req[checkpoint_key]) not in checkpoint_alisases: - raise RuntimeError(f"model {v!r} not found") + checkpoint_name = req.get("sd_model_checkpoint", None) + if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases: + raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): shared.opts.set(k, v) From da14f6a6632e67cacaeaac7441344f0848f66114 Mon Sep 17 00:00:00 2001 From: NoCrypt <57245077+NoCrypt@users.noreply.github.com> Date: Wed, 28 Jun 2023 10:16:44 +0700 Subject: [PATCH 044/178] Add options to change colors in grid --- modules/images.py | 13 +++++-------- modules/shared.py | 5 ++++- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/modules/images.py b/modules/images.py index 1906e2ab..320008be 100644 --- a/modules/images.py +++ b/modules/images.py @@ -10,7 +10,7 @@ import re import numpy as np import piexif import piexif.helper -from PIL import Image, ImageFont, ImageDraw, PngImagePlugin +from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin import string import json import hashlib @@ -156,10 +156,10 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0: fontsize -= 1 fnt = get_font(fontsize) - drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") + drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=ImageColor.getcolor(opts.grid_text_color_active, 'RGB') if line.is_active else ImageColor.getcolor(opts.grid_text_color_inactive, 'RGB'), anchor="mm", align="center") if not line.is_active: - drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4) + drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=ImageColor.getcolor(opts.grid_text_color_inactive, 'RGB'), width=4) draw_y += line.size[1] + line_spacing @@ -168,9 +168,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): fnt = get_font(fontsize) - color_active = (0, 0, 0) - color_inactive = (153, 153, 153) - pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4 cols = im.width // width @@ -179,7 +176,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}' assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}' - calc_img = Image.new("RGB", (1, 1), "white") + calc_img = Image.new("RGB", (1, 1), ImageColor.getcolor(opts.grid_background, 'RGB')) calc_d = ImageDraw.Draw(calc_img) for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)): @@ -200,7 +197,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2 - result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white") + result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), ImageColor.getcolor(opts.grid_background, 'RGB')) for row in range(rows): for col in range(cols): diff --git a/modules/shared.py b/modules/shared.py index 203ee1b9..4a83cca4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -413,6 +413,10 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), + "font": OptionInfo("", "Font for image grids that have text"), + "grid_text_color_active": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), + "grid_text_color_inactive": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), + "grid_background": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), })) options_templates.update(options_section(('optimizations', "Optimizations"), { @@ -471,7 +475,6 @@ options_templates.update(options_section(('ui', "User interface"), { "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), - "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"), From b0ec69b360835a901a1aa57df1f7c8c9d55bf31c Mon Sep 17 00:00:00 2001 From: hako-mikan <122196982+hako-mikan@users.noreply.github.com> Date: Wed, 28 Jun 2023 18:37:08 +0900 Subject: [PATCH 045/178] add 'before_hr callback' script callback --- modules/processing.py | 3 +++ modules/scripts.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index 8da73884..35463c37 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1074,6 +1074,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) + if self.scripts is not None: + self.scripts.before_hr(self) + samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) diff --git a/modules/scripts.py b/modules/scripts.py index 99bf836a..6485f398 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -186,6 +186,11 @@ class Script: return f'script_{tabname}{title}_{item_id}' + def before_hr(self, p ,*args): + """ + This function is called before hires fix start. + """ + pass current_basedir = paths.script_path @@ -548,6 +553,15 @@ class ScriptRunner: self.scripts[si].args_to = args_to + def before_hr(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.before_hr(p, *script_args) + except Exception: + errors.report(f"Error running before_hr: {script.filename}", exc_info=True) + + scripts_txt2img: ScriptRunner = None scripts_img2img: ScriptRunner = None scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None From 24d4475bdb623b321bc3fdf7205ae4f3221b6dd5 Mon Sep 17 00:00:00 2001 From: catalpaaa Date: Wed, 28 Jun 2023 03:15:01 -0700 Subject: [PATCH 046/178] fixing --subpath on newer gradio version --- webui.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/webui.py b/webui.py index 136d036d..02880b85 100644 --- a/webui.py +++ b/webui.py @@ -359,7 +359,11 @@ def api_only(): modules.script_callbacks.app_started_callback(None, app) print(f"Startup time: {startup_timer.summary()}.") - api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) + api.launch( + server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", + port=cmd_opts.port if cmd_opts.port else 7861, + root_path = f"/{cmd_opts.subpath}" + ) def stop_route(request): @@ -403,6 +407,7 @@ def webui(): "docs_url": "/docs", "redoc_url": "/redoc", }, + root_path = f"/{cmd_opts.subpath}", ) if cmd_opts.add_stop_route: app.add_route("/_stop", stop_route, methods=["POST"]) @@ -436,11 +441,6 @@ def webui(): timer.startup_record = startup_timer.dump() print(f"Startup time: {startup_timer.summary()}.") - if cmd_opts.subpath: - redirector = FastAPI() - redirector.get("/") - gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}") - try: while True: server_command = shared.state.wait_for_server_command(timeout=5) From 45ab7475d61fe42b70c37541974c03736cf73189 Mon Sep 17 00:00:00 2001 From: NoCrypt <57245077+NoCrypt@users.noreply.github.com> Date: Wed, 28 Jun 2023 17:55:58 +0700 Subject: [PATCH 047/178] Revision --- modules/images.py | 13 +++++++++---- modules/shared.py | 6 +++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/modules/images.py b/modules/images.py index 320008be..913c3c2f 100644 --- a/modules/images.py +++ b/modules/images.py @@ -139,6 +139,11 @@ class GridAnnotation: def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): + + color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB') + color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB') + color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB') + def wrap(drawing, text, font, line_length): lines = [''] for word in text.split(): @@ -156,10 +161,10 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0: fontsize -= 1 fnt = get_font(fontsize) - drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=ImageColor.getcolor(opts.grid_text_color_active, 'RGB') if line.is_active else ImageColor.getcolor(opts.grid_text_color_inactive, 'RGB'), anchor="mm", align="center") + drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") if not line.is_active: - drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=ImageColor.getcolor(opts.grid_text_color_inactive, 'RGB'), width=4) + drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4) draw_y += line.size[1] + line_spacing @@ -176,7 +181,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}' assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}' - calc_img = Image.new("RGB", (1, 1), ImageColor.getcolor(opts.grid_background, 'RGB')) + calc_img = Image.new("RGB", (1, 1), color_background) calc_d = ImageDraw.Draw(calc_img) for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)): @@ -197,7 +202,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2 - result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), ImageColor.getcolor(opts.grid_background, 'RGB')) + result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background) for row in range(rows): for col in range(cols): diff --git a/modules/shared.py b/modules/shared.py index 4a83cca4..22e6bd0b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -414,9 +414,9 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), "font": OptionInfo("", "Font for image grids that have text"), - "grid_text_color_active": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), - "grid_text_color_inactive": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), - "grid_background": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), + "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), + "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), + "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), })) options_templates.update(options_section(('optimizations', "Optimizations"), { From d22eb8a17f8d8c0e8018d9f9c71f7a96108544ee Mon Sep 17 00:00:00 2001 From: NoCrypt <57245077+NoCrypt@users.noreply.github.com> Date: Wed, 28 Jun 2023 17:57:34 +0700 Subject: [PATCH 048/178] Fix lint --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index 913c3c2f..3e6988fc 100644 --- a/modules/images.py +++ b/modules/images.py @@ -139,7 +139,7 @@ class GridAnnotation: def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): - + color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB') color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB') color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB') From f74fb5049506b85a98b02b1c2fd7361e9f751980 Mon Sep 17 00:00:00 2001 From: NoCrypt <57245077+NoCrypt@users.noreply.github.com> Date: Wed, 28 Jun 2023 20:24:57 +0700 Subject: [PATCH 049/178] Move change colors options to Saving images/grids --- modules/shared.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 22e6bd0b..76d8e221 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -311,6 +311,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), + "font": OptionInfo("", "Font for image grids that have text"), + "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), + "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), + "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), @@ -413,10 +417,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), - "font": OptionInfo("", "Font for image grids that have text"), - "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), - "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), - "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), })) options_templates.update(options_section(('optimizations', "Optimizations"), { From 9c2a7f1e8bafcb59e566bf568fdefe1be95905fe Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 19 Jun 2023 15:37:20 +0900 Subject: [PATCH 050/178] add callback after_extra_networks_activate --- modules/extra_networks.py | 3 +++ modules/scripts.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 1f093df2..41799b0a 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -103,6 +103,9 @@ def activate(p, extra_network_data): except Exception as e: errors.display(e, f"activating extra network {extra_network_name}") + if p.scripts is not None: + p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data) + def deactivate(p, extra_network_data): """call deactivate for extra networks in extra_network_data in specified order, then call diff --git a/modules/scripts.py b/modules/scripts.py index 99bf836a..340f1480 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -116,6 +116,21 @@ class Script: pass + def after_extra_networks_activate(self, p, *args, **kwargs): + """ + Calledafter extra networks activation, before conds calculation + allow modification of the network after extra networks activation been applied + won't be call if p.disable_extra_networks + + **kwargs will have those items: + - batch_number - index of current batch, from 0 to number of batches-1 + - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things + - seeds - list of seeds for current batch + - subseeds - list of subseeds for current batch + - extra_network_data - list of ExtraNetworkParams for current stage + """ + pass + def process_batch(self, p, *args, **kwargs): """ Same as process(), but called for every batch. @@ -483,6 +498,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) + def after_extra_networks_activate(self, p, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.after_extra_networks_activate(p, *script_args, **kwargs) + except Exception: + errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True) + def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: try: From 0b0767939d4cc0868a10b6c0978f7b2d963dea1a Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 28 Jun 2023 17:51:27 -0600 Subject: [PATCH 051/178] Correctly remove end parenthesis with ctrl+up/down --- javascript/edit-attention.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index ffa73147..8906c892 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -100,11 +100,12 @@ function keyupEditAttention(event) { if (String(weight).length == 1) weight += ".0"; if (closeCharacter == ')' && weight == 1) { - text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5); + var endParenPos = text.substring(selectionEnd).indexOf(')'); + text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1); selectionStart--; selectionEnd--; } else { - text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end); } target.focus(); From cc9c1719786de00d4a5bfcf83be4bf2808cf0cb5 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 29 Jun 2023 14:21:28 +0900 Subject: [PATCH 052/178] rename --add-stop-route to --api-server-stop --- .github/workflows/run_tests.yaml | 2 +- modules/api/api.py | 2 +- modules/cmd_args.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 96546011..178c026a 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -42,7 +42,7 @@ jobs: --no-half --disable-opt-split-attention --use-cpu all - --add-stop-route + --api-server-stop 2>&1 | tee output.txt & - name: Run tests run: | diff --git a/modules/api/api.py b/modules/api/api.py index 279c384a..adc633db 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -202,7 +202,7 @@ class Api: self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) - if shared.cmd_opts.add_stop_route: + if shared.cmd_opts.api_server_stop: self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"]) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 624dcb4f..278a605e 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='enable server stop/restart/kill via api') +parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') From 0bc0e652a3d8cde5533af52c4f232c213b9989f0 Mon Sep 17 00:00:00 2001 From: hunshcn Date: Thu, 29 Jun 2023 18:12:55 +0800 Subject: [PATCH 053/178] sync default value of process_focal_crop_entropy_weight between ui and api --- modules/textual_inversion/preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 0d4c3f84..dbd856bd 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,7 +7,7 @@ from modules import paths, shared, images, deepbooru from modules.textual_inversion import autocrop -def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): +def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): try: if process_caption: shared.interrogator.load() From d47324b898d057c0f854b9be891f2483a2b7001f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 29 Jun 2023 19:25:18 +0900 Subject: [PATCH 054/178] add stars --- modules/ui_extensions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 278bf5e4..ac239d64 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -424,6 +424,7 @@ sort_ordering = [ (False, lambda x: x.get('name', 'z')), (True, lambda x: x.get('name', 'z')), (False, lambda x: 'z'), + (True, lambda x: x.get('stars', 0)), ] @@ -451,6 +452,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" for ext in sorted(extlist, key=sort_function, reverse=sort_reverse): name = ext.get("name", "noname") + stars = int(ext.get("stars", 0)) added = ext.get('added', 'unknown') url = ext.get("url", None) description = ext.get("description", "") @@ -478,7 +480,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" code += f""" - + @@ -562,7 +564,7 @@ def create_ui(): with gr.Row(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) - sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", "stars"], type="index") with gr.Row(): search_extensions_text = gr.Text(label="Search").style(container=False) From b1c6e39620dd398ae6a2cb1e9236b65a7294cf59 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 29 Jun 2023 19:25:34 +0900 Subject: [PATCH 055/178] starts left --- style.css | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/style.css b/style.css index e1df716f..5073f0f0 100644 --- a/style.css +++ b/style.css @@ -704,11 +704,24 @@ table.popup-table .link{ margin: 0; } -#available_extensions .date_added{ - opacity: 0.85; +#available_extensions .info{ + margin: 0.5em 0; + display: flex; + margin-top: auto; + opacity: 0.80; font-size: 90%; } +#available_extensions .date_added{ + margin-right: auto; + display: inline-block; +} + +#available_extensions .star_count{ + margin-left: auto; + display: inline-block; +} + /* replace original footer with ours */ footer { From 0416a7bfbaecab20a4ae4cd8330faee556bb3d89 Mon Sep 17 00:00:00 2001 From: Akiba Date: Thu, 29 Jun 2023 18:46:52 +0800 Subject: [PATCH 056/178] fix can't get current hash --- modules/launch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 97539e68..0e0dbca4 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -142,7 +142,7 @@ def git_clone(url, dir, name, commithash=None): if commithash is None: return - current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() + current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip() if current_hash == commithash: return From 2ccc832b3333fe520961466aa1f05b24aafdd792 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 29 Jun 2023 22:46:59 +0900 Subject: [PATCH 057/178] add extensions Update Created dates with sorting --- modules/ui_extensions.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index ac239d64..dff522ef 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -424,10 +424,19 @@ sort_ordering = [ (False, lambda x: x.get('name', 'z')), (True, lambda x: x.get('name', 'z')), (False, lambda x: 'z'), + (True, lambda x: x.get('commit_time', '')), + (True, lambda x: x.get('created_at', '')), (True, lambda x: x.get('stars', 0)), ] +def get_date(info: dict, key): + try: + return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d") + except (ValueError, TypeError): + return '' + + def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): extlist = available_extensions["extensions"] installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} @@ -454,6 +463,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" name = ext.get("name", "noname") stars = int(ext.get("stars", 0)) added = ext.get('added', 'unknown') + update_time = get_date(ext, 'commit_time') + create_time = get_date(ext, 'created_at') url = ext.get("url", None) description = ext.get("description", "") extension_tags = ext.get("tags", []) @@ -480,7 +491,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" code += f""" - + @@ -564,7 +576,7 @@ def create_ui(): with gr.Row(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) - sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", "stars"], type="index") + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index") with gr.Row(): search_extensions_text = gr.Text(label="Search").style(container=False) From 8a07c59baa670f8ed54757f7ac7580b27ecac3dd Mon Sep 17 00:00:00 2001 From: gshawn3 <133769806+gshawn3@users.noreply.github.com> Date: Fri, 30 Jun 2023 03:49:26 -0700 Subject: [PATCH 058/178] fix for #11534: canvas zoom and pan extension hijacking shortcut keys --- .../canvas-zoom-and-pan/javascript/zoom.js | 49 ++++++++++++------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index 5ebd2073..ed3e52bc 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -608,23 +608,29 @@ onUiLoaded(async() => { // Handle keydown events function handleKeyDown(event) { - const hotkeyActions = { - [hotkeysConfig.canvas_hotkey_reset]: resetZoom, - [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, - [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen - }; + // before activating shortcut, ensure user is not actively typing in an input field + if(event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') { + event.preventDefault; + } else { - const action = hotkeyActions[event.code]; - if (action) { - event.preventDefault(); - action(event); - } + const hotkeyActions = { + [hotkeysConfig.canvas_hotkey_reset]: resetZoom, + [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, + [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen + }; - if ( - isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) || - isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust) - ) { - event.preventDefault(); + const action = hotkeyActions[event.code]; + if (action) { + event.preventDefault(); + action(event); + } + + if ( + isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) || + isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust) + ) { + event.preventDefault(); + } } } @@ -687,10 +693,15 @@ onUiLoaded(async() => { // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element. function handleMoveKeyDown(e) { if (e.code === hotkeysConfig.canvas_hotkey_move) { - if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) { - e.preventDefault(); - document.activeElement.blur(); - isMoving = true; + // before activating shortcut, ensure user is not actively typing in an input field + if(e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') { + event.preventDefault; + } else { + if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) { + e.preventDefault(); + document.activeElement.blur(); + isMoving = true; + } } } } From 7f46f81dd7b517e829395734750f0eb8360675d4 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 1 Jul 2023 17:20:56 -0600 Subject: [PATCH 059/178] Change default seed_resize to 0 --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 8da73884..9e838aad 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -109,7 +109,7 @@ class StableDiffusionProcessing: cached_uc = [None, None] cached_c = [None, None] - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = 0, seed_resize_from_w: int = 0, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) From 74d001bc68c2106aa963e3474eee70327b8f3760 Mon Sep 17 00:00:00 2001 From: ramyma Date: Sun, 2 Jul 2023 04:59:59 +0300 Subject: [PATCH 060/178] Hotfix: call processing close to cleanup API generation calls --- modules/api/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index 279c384a..f10e3fe3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -335,6 +335,7 @@ class Api: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -392,6 +393,7 @@ class Api: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] From 8519d52ef505204be53c68e58edc6569ca5cfb32 Mon Sep 17 00:00:00 2001 From: Danil Boldyrev Date: Sun, 2 Jul 2023 19:20:49 +0300 Subject: [PATCH 061/178] fixing the copy/paste function, correct code --- .../canvas-zoom-and-pan/javascript/zoom.js | 67 +++++++++++-------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index ed3e52bc..29f43a3f 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -608,29 +608,34 @@ onUiLoaded(async() => { // Handle keydown events function handleKeyDown(event) { + // Disable key locks to make pasting from the buffer work correctly + if ((event.ctrlKey && event.code === 'KeyV') || event.code === "F5") { + return; + } + // before activating shortcut, ensure user is not actively typing in an input field - if(event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') { - event.preventDefault; - } else { + if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') { + return; + } - const hotkeyActions = { - [hotkeysConfig.canvas_hotkey_reset]: resetZoom, - [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, - [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen - }; - const action = hotkeyActions[event.code]; - if (action) { - event.preventDefault(); - action(event); - } + const hotkeyActions = { + [hotkeysConfig.canvas_hotkey_reset]: resetZoom, + [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, + [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen + }; - if ( - isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) || - isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust) - ) { - event.preventDefault(); - } + const action = hotkeyActions[event.code]; + if (action) { + event.preventDefault(); + action(event); + } + + if ( + isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) || + isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust) + ) { + event.preventDefault(); } } @@ -692,16 +697,22 @@ onUiLoaded(async() => { // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element. function handleMoveKeyDown(e) { + + // Disable key locks to make pasting from the buffer work correctly + if ((e.ctrlKey && e.code === 'KeyV') || e.code === "F5") { + return; + } + + // before activating shortcut, ensure user is not actively typing in an input field + if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') { + return; + } + if (e.code === hotkeysConfig.canvas_hotkey_move) { - // before activating shortcut, ensure user is not actively typing in an input field - if(e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') { - event.preventDefault; - } else { - if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) { - e.preventDefault(); - document.activeElement.blur(); - isMoving = true; - } + if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) { + e.preventDefault(); + document.activeElement.blur(); + isMoving = true; } } } From 5a32d4fcb195f7ee5be2617d9f776c01fd0437b7 Mon Sep 17 00:00:00 2001 From: onyasumi Date: Mon, 3 Jul 2023 07:15:19 +0000 Subject: [PATCH 062/178] Fix launch script to be runnable from any directory --- webui.sh | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/webui.sh b/webui.sh index 5c8d977c..8a3c6f12 100755 --- a/webui.sh +++ b/webui.sh @@ -4,26 +4,28 @@ # change the variables in webui-user.sh instead # ################################################# +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + # If run from macOS, load defaults from webui-macos-env.sh if [[ "$OSTYPE" == "darwin"* ]]; then - if [[ -f webui-macos-env.sh ]] + if [[ -f "$SCRIPT_DIR"/webui-macos-env.sh ]] then - source ./webui-macos-env.sh + source "$SCRIPT_DIR"/webui-macos-env.sh fi fi # Read variables from webui-user.sh # shellcheck source=/dev/null -if [[ -f webui-user.sh ]] +if [[ -f "$SCRIPT_DIR"/webui-user.sh ]] then - source ./webui-user.sh + source "$SCRIPT_DIR"/webui-user.sh fi # Set defaults # Install directory without trailing slash if [[ -z "${install_dir}" ]] then - install_dir="$(pwd)" + install_dir="$(dirname "$0")" fi # Name of the subdirectory (defaults to stable-diffusion-webui) From e33e2c51753b91d836aabc52f1f8d67d7de05f86 Mon Sep 17 00:00:00 2001 From: Frank Tao <48634762+onyasumi@users.noreply.github.com> Date: Mon, 3 Jul 2023 03:17:27 -0400 Subject: [PATCH 063/178] Update webui.sh --- webui.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.sh b/webui.sh index 8a3c6f12..246381fc 100755 --- a/webui.sh +++ b/webui.sh @@ -25,7 +25,7 @@ fi # Install directory without trailing slash if [[ -z "${install_dir}" ]] then - install_dir="$(dirname "$0")" + install_dir="$SCRIPT_DIR" fi # Name of the subdirectory (defaults to stable-diffusion-webui) From b70001e618d0f0015273e1313cc7ebe3002a4510 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 30 Jun 2023 13:44:58 +0300 Subject: [PATCH 064/178] Add SD_WEBUI_LOG_LEVEL envvar --- webui.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/webui.py b/webui.py index bad29f28..1b44d4ad 100644 --- a/webui.py +++ b/webui.py @@ -18,6 +18,17 @@ from packaging import version import logging +# We can't use cmd_opts for this because it will not have been initialized at this point. +log_level = os.environ.get("SD_WEBUI_LOG_LEVEL") +if log_level: + log_level = getattr(logging, log_level.upper(), None) or logging.INFO + logging.basicConfig( + level=log_level, + format='%(asctime)s %(levelname)s [%(name)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + ) + +logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) from modules import paths, timer, import_hook, errors, devices # noqa: F401 From f44feb6a10aacc6a5ff4c9275fba2546b2858935 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 30 Jun 2023 13:11:31 +0300 Subject: [PATCH 065/178] Add job argument to State.begin() --- modules/api/api.py | 14 +++++++------- modules/call_queue.py | 2 +- modules/extras.py | 3 +-- modules/interrogate.py | 3 +-- modules/postprocessing.py | 3 +-- modules/shared.py | 4 ++-- 6 files changed, 13 insertions(+), 16 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 279c384a..3ea099ad 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -327,7 +327,7 @@ class Api: p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples - shared.state.begin() + shared.state.begin(job="scripts_txt2img") if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here @@ -384,7 +384,7 @@ class Api: p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples - shared.state.begin() + shared.state.begin(job="scripts_img2img") if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here @@ -599,7 +599,7 @@ class Api: def create_embedding(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="create_embedding") filename = create_embedding(**args) # create empty embedding sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used shared.state.end() @@ -610,7 +610,7 @@ class Api: def create_hypernetwork(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="create_hypernetwork") filename = create_hypernetwork(**args) # create empty embedding shared.state.end() return models.CreateResponse(info=f"create hypernetwork filename: {filename}") @@ -620,7 +620,7 @@ class Api: def preprocess(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="preprocess") preprocess(**args) # quick operation unless blip/booru interrogation is enabled shared.state.end() return models.PreprocessResponse(info = 'preprocess complete') @@ -636,7 +636,7 @@ class Api: def train_embedding(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="train_embedding") apply_optimizations = shared.opts.training_xattention_optimizations error = None filename = '' @@ -657,7 +657,7 @@ class Api: def train_hypernetwork(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="train_hypernetwork") shared.loaded_hypernetworks = [] apply_optimizations = shared.opts.training_xattention_optimizations error = None diff --git a/modules/call_queue.py b/modules/call_queue.py index 69bf63d2..3b94f8a4 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -30,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): id_task = None with queue_lock: - shared.state.begin() + shared.state.begin(job=id_task) progress.start_task(id_task) try: diff --git a/modules/extras.py b/modules/extras.py index 830b53aa..e9c0263e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -73,8 +73,7 @@ def to_half(tensor, enable): def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): - shared.state.begin() - shared.state.job = 'model-merge' + shared.state.begin(job="model-merge") def fail(message): shared.state.textinfo = message diff --git a/modules/interrogate.py b/modules/interrogate.py index 9b2c5b60..a3ae1dd5 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -184,8 +184,7 @@ class InterrogateModels: def interrogate(self, pil_image): res = "" - shared.state.begin() - shared.state.job = 'interrogate' + shared.state.begin(job="interrogate") try: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 736315e2..544b2f72 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -9,8 +9,7 @@ from modules.shared import opts def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() - shared.state.begin() - shared.state.job = 'extras' + shared.state.begin(job="extras") image_data = [] image_names = [] diff --git a/modules/shared.py b/modules/shared.py index 203ee1b9..7df2879c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -173,7 +173,7 @@ class State: return obj - def begin(self): + def begin(self, job: str = "(unknown)"): self.sampling_step = 0 self.job_count = -1 self.processing_has_refined_job_count = False @@ -187,7 +187,7 @@ class State: self.interrupted = False self.textinfo = None self.time_start = time.time() - + self.job = job devices.torch_gc() def end(self): From e4303443477b9ac3c90ec4dd58a4810f7ac1eabe Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 30 Jun 2023 13:11:49 +0300 Subject: [PATCH 066/178] API: use finally: for state.end() --- modules/api/api.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 3ea099ad..8b79495d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -602,37 +602,35 @@ class Api: shared.state.begin(job="create_embedding") filename = create_embedding(**args) # create empty embedding sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used - shared.state.end() return models.CreateResponse(info=f"create embedding filename: {filename}") except AssertionError as e: - shared.state.end() return models.TrainResponse(info=f"create embedding error: {e}") + finally: + shared.state.end() + def create_hypernetwork(self, args: dict): try: shared.state.begin(job="create_hypernetwork") filename = create_hypernetwork(**args) # create empty embedding - shared.state.end() return models.CreateResponse(info=f"create hypernetwork filename: {filename}") except AssertionError as e: - shared.state.end() return models.TrainResponse(info=f"create hypernetwork error: {e}") + finally: + shared.state.end() def preprocess(self, args: dict): try: shared.state.begin(job="preprocess") preprocess(**args) # quick operation unless blip/booru interrogation is enabled shared.state.end() - return models.PreprocessResponse(info = 'preprocess complete') + return models.PreprocessResponse(info='preprocess complete') except KeyError as e: - shared.state.end() return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") - except AssertionError as e: - shared.state.end() + except Exception as e: return models.PreprocessResponse(info=f"preprocess error: {e}") - except FileNotFoundError as e: + finally: shared.state.end() - return models.PreprocessResponse(info=f'preprocess error: {e}') def train_embedding(self, args: dict): try: @@ -649,11 +647,11 @@ class Api: finally: if not apply_optimizations: sd_hijack.apply_optimizations() - shared.state.end() return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except AssertionError as msg: - shared.state.end() + except Exception as msg: return models.TrainResponse(info=f"train embedding error: {msg}") + finally: + shared.state.end() def train_hypernetwork(self, args: dict): try: @@ -675,9 +673,10 @@ class Api: sd_hijack.apply_optimizations() shared.state.end() return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except AssertionError: + except Exception as exc: + return models.TrainResponse(info=f"train embedding error: {exc}") + finally: shared.state.end() - return models.TrainResponse(info=f"train embedding error: {error}") def get_memory(self): try: From 522a8b9f629940a205812b5b023f25c051f3c8d8 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 30 Jun 2023 13:24:17 +0300 Subject: [PATCH 067/178] Add a status logger in modules.shared --- modules/shared.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modules/shared.py b/modules/shared.py index 7df2879c..9ab9d98b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -4,6 +4,7 @@ import os import sys import threading import time +import logging import gradio as gr import torch @@ -18,6 +19,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi from ldm.models.diffusion.ddpm import LatentDiffusion from typing import Optional +log = logging.getLogger(__name__) + demo = None parser = cmd_args.parser @@ -144,12 +147,15 @@ class State: def request_restart(self) -> None: self.interrupt() self.server_command = "restart" + log.info("Received restart request") def skip(self): self.skipped = True + log.info("Received skip request") def interrupt(self): self.interrupted = True + log.info("Received interrupt request") def nextjob(self): if opts.live_previews_enable and opts.show_progress_every_n_steps == -1: @@ -189,8 +195,11 @@ class State: self.time_start = time.time() self.job = job devices.torch_gc() + log.info("Starting job %s", job) def end(self): + duration = time.time() - self.time_start + log.info("Ending job %s (%.2f seconds)", self.job, duration) self.job = "" self.job_count = 0 From 08f9b705cda4277aed49ed00c405ada2925e3b50 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 3 Jul 2023 13:08:28 +0300 Subject: [PATCH 068/178] Use read_info_from_image in postprocessing Avoids bad keys such as `exif` ending up in the "PNG info" passed forward --- modules/postprocessing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 736315e2..38544c38 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -54,7 +54,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, for image, name in zip(image_data, image_names): shared.state.textinfo = name - existing_pnginfo = image.info or {} + parameters, existing_pnginfo = images.read_info_from_image(image) + if parameters: + existing_pnginfo["parameters"] = parameters pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) From b2c574891f492d00e310e387a024638a7bcf2353 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 3 Jul 2023 13:09:37 +0300 Subject: [PATCH 069/178] read_info_from_image: add `photoshop` to ignored --- modules/images.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/modules/images.py b/modules/images.py index 1906e2ab..ac53a3c5 100644 --- a/modules/images.py +++ b/modules/images.py @@ -662,6 +662,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i return fullfn, txt_fullfn +IGNORED_INFO_KEYS = { + 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif', + 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression', + 'icc_profile', 'chromaticity', 'photoshop', +} + + def read_info_from_image(image): items = image.info or {} @@ -679,9 +686,7 @@ def read_info_from_image(image): items['exif comment'] = exif_comment geninfo = exif_comment - for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif', - 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression', - 'icc_profile', 'chromaticity']: + for field in IGNORED_INFO_KEYS: items.pop(field, None) if items.get("Software", None) == "NovelAI": From 96f0593c8fcfb5d31da9731d995c6d6f2ad77829 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 3 Jul 2023 13:10:20 +0300 Subject: [PATCH 070/178] read_info_from_image: add type --- modules/images.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index ac53a3c5..74a10a7b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import pytz @@ -669,7 +671,7 @@ IGNORED_INFO_KEYS = { } -def read_info_from_image(image): +def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]: items = image.info or {} geninfo = items.pop('parameters', None) From 5c6a33b3e11f5aa7b2fc56753c5a724e1351ce81 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 3 Jul 2023 13:10:42 +0300 Subject: [PATCH 071/178] read_info_from_image: don't mutate info in passed-in image --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index 74a10a7b..ec421993 100644 --- a/modules/images.py +++ b/modules/images.py @@ -672,7 +672,7 @@ IGNORED_INFO_KEYS = { def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]: - items = image.info or {} + items = (image.info or {}).copy() geninfo = items.pop('parameters', None) From 32788873176e9d79e1fffd6f89f94b6d0ec8bb91 Mon Sep 17 00:00:00 2001 From: ramyma Date: Mon, 3 Jul 2023 20:02:30 +0300 Subject: [PATCH 072/178] Handle cleanup in case there's an exception thrown --- modules/api/api.py | 55 +++++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index f10e3fe3..d9278e9e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -323,19 +323,21 @@ class Api: with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) - p.scripts = script_runner - p.outpath_grids = opts.outdir_txt2img_grids - p.outpath_samples = opts.outdir_txt2img_samples + try: + p.scripts = script_runner + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples - shared.state.begin() - if selectable_scripts is not None: - p.script_args = script_args - processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here - else: - p.script_args = tuple(script_args) # Need to pass args as tuple here - processed = process_images(p) - shared.state.end() - p.close() + shared.state.begin() + if selectable_scripts is not None: + p.script_args = script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here + else: + p.script_args = tuple(script_args) # Need to pass args as tuple here + processed = process_images(p) + shared.state.end() + finally: + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -380,20 +382,23 @@ class Api: with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) - p.init_images = [decode_base64_to_image(x) for x in init_images] - p.scripts = script_runner - p.outpath_grids = opts.outdir_img2img_grids - p.outpath_samples = opts.outdir_img2img_samples + try: + p.init_images = [decode_base64_to_image(x) for x in init_images] + p.scripts = script_runner + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples - shared.state.begin() - if selectable_scripts is not None: - p.script_args = script_args - processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here - else: - p.script_args = tuple(script_args) # Need to pass args as tuple here - processed = process_images(p) - shared.state.end() - p.close() + shared.state.begin() + if selectable_scripts is not None: + p.script_args = script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here + else: + p.script_args = tuple(script_args) # Need to pass args as tuple here + processed = process_images(p) + shared.state.end() + + finally: + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] From c1c04928596f69ddb39b8841a8435ecefb0594e9 Mon Sep 17 00:00:00 2001 From: ramyma Date: Mon, 3 Jul 2023 20:17:47 +0300 Subject: [PATCH 073/178] Use contextlib for closing the generation process --- modules/api/api.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index d9278e9e..e92c2938 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -30,6 +30,7 @@ from modules import devices from typing import Dict, List, Any import piexif import piexif.helper +from contextlib import closing def script_name_to_index(name, scripts): @@ -322,8 +323,7 @@ class Api: args.pop('save_images', None) with self.queue_lock: - p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) - try: + with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: p.scripts = script_runner p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples @@ -336,8 +336,6 @@ class Api: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() - finally: - p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -381,8 +379,7 @@ class Api: args.pop('save_images', None) with self.queue_lock: - p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) - try: + with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: p.init_images = [decode_base64_to_image(x) for x in init_images] p.scripts = script_runner p.outpath_grids = opts.outdir_img2img_grids @@ -397,8 +394,6 @@ class Api: processed = process_images(p) shared.state.end() - finally: - p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] From f731a728c68035ee36317ed0096ac5ecbfd50553 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Mon, 3 Jul 2023 11:41:10 -0600 Subject: [PATCH 074/178] Check seed_resize_from <= 0 --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 9e838aad..dc552121 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -109,7 +109,7 @@ class StableDiffusionProcessing: cached_uc = [None, None] cached_c = [None, None] - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = 0, seed_resize_from_w: int = 0, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -573,7 +573,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), - "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), + "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, From f325783abd828c3b90b4d0aa19031401c0ba4c4c Mon Sep 17 00:00:00 2001 From: Danil Boldyrev Date: Tue, 4 Jul 2023 22:26:43 +0300 Subject: [PATCH 075/178] made the blur function optional, added exclusion buttons --- .../canvas-zoom-and-pan/javascript/zoom.js | 20 ++++++++++++------- .../scripts/hotkey_config.py | 1 + 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index 29f43a3f..30199dcd 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -200,7 +200,8 @@ onUiLoaded(async() => { canvas_hotkey_move: "KeyF", canvas_hotkey_overlap: "KeyO", canvas_disabled_functions: [], - canvas_show_tooltip: true + canvas_show_tooltip: true, + canvas_blur_prompt: false }; const functionMap = { @@ -609,13 +610,15 @@ onUiLoaded(async() => { // Handle keydown events function handleKeyDown(event) { // Disable key locks to make pasting from the buffer work correctly - if ((event.ctrlKey && event.code === 'KeyV') || event.code === "F5") { + if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === "F5") { return; } // before activating shortcut, ensure user is not actively typing in an input field - if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') { - return; + if (!hotkeysConfig.canvas_blur_prompt) { + if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') { + return; + } } @@ -699,15 +702,18 @@ onUiLoaded(async() => { function handleMoveKeyDown(e) { // Disable key locks to make pasting from the buffer work correctly - if ((e.ctrlKey && e.code === 'KeyV') || e.code === "F5") { + if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && event.code === 'KeyC') || e.code === "F5") { return; } // before activating shortcut, ensure user is not actively typing in an input field - if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') { - return; + if (!hotkeysConfig.canvas_blur_prompt) { + if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') { + return; + } } + if (e.code === hotkeysConfig.canvas_hotkey_move) { if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) { e.preventDefault(); diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py index 1b6683aa..380176ce 100644 --- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -9,5 +9,6 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"), "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"), + "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"), "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}), })) From c602471b85d270e8c36707817d9bad92b0ff991e Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 5 Jul 2023 03:19:26 -0600 Subject: [PATCH 076/178] Allow gif for extra network previews --- modules/ui_extra_networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a7d3bc79..1efd00b0 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -30,8 +30,8 @@ def fetch_file(filename: str = ""): raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") ext = os.path.splitext(filename)[1].lower() - if ext not in (".png", ".jpg", ".jpeg", ".webp"): - raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.") + if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"): + raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.") # would profit from returning 304 return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) From fb661e089f24a3056b9724c580e3badc214467cc Mon Sep 17 00:00:00 2001 From: semjon00 Date: Wed, 5 Jul 2023 15:39:04 +0300 Subject: [PATCH 077/178] Fix throwing exception when trying to resize image with I;16 mode --- modules/images.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/modules/images.py b/modules/images.py index 1906e2ab..91e3fae2 100644 --- a/modules/images.py +++ b/modules/images.py @@ -639,12 +639,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i oversize = image.width > opts.target_side_length or image.height > opts.target_side_length if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024): ratio = image.width / image.height - + resize_to = None if oversize and ratio > 1: - image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS) + resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width) elif oversize: - image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS) + resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length) + if resize_to is not None: + try: + # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16 + image = image.resize(resize_to, LANCZOS) + except: + image = image.resize(resize_to) try: _atomically_save_image(image, fullfn_without_extension, ".jpg") except Exception as e: From daf41a273485e865c9c9ef458b2c26be4422bcb2 Mon Sep 17 00:00:00 2001 From: Hao-Wu Date: Thu, 6 Jul 2023 15:37:10 +0800 Subject: [PATCH 078/178] Fix warning of 'has_mps' is deprecated from PyTorch --- modules/mac_specific.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/modules/mac_specific.py b/modules/mac_specific.py index d74c6b95..735847f5 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -4,16 +4,21 @@ from modules.sd_hijack_utils import CondFunc from packaging import version -# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. -# check `getattr` and try it for compatibility +# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, +# use check `getattr` and try it for compatibility. +# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty, +# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279 def check_for_mps() -> bool: - if not getattr(torch, 'has_mps', False): - return False - try: - torch.zeros(1).to(torch.device("mps")) - return True - except Exception: - return False + if version.parse(torch.__version__) <= version.parse("2.0.1"): + if not getattr(torch, 'has_mps', False): + return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False + else: + return torch.backends.mps.is_available() and torch.backends.mps.is_built() has_mps = check_for_mps() From 259967b7c60cbd2aeb091e691b5f49d9fb64b872 Mon Sep 17 00:00:00 2001 From: jovijovi Date: Thu, 6 Jul 2023 18:43:17 +0800 Subject: [PATCH 079/178] fix(api): convert to "RGB" if image mode is "RGBA" --- modules/api/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index 2e49526e..6507f641 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -84,6 +84,8 @@ def encode_pil_to_base64(image): image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + if image.mode == "RGBA": + image = image.convert("RGB") parameters = image.info.get('parameters', None) exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } From c258dd34a888b7c6c9e4c9bbef76732d9d7db6e7 Mon Sep 17 00:00:00 2001 From: Neil Mahseth Date: Thu, 6 Jul 2023 22:02:47 +0530 Subject: [PATCH 080/178] Fix UnicodeEncodeError when writing to file CLIP Interrogator Batch Mode The code snippet print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a')) raises a UnicodeEncodeError with the message "'charmap' codec can't encode character '\u016b' in position 129". This error occurs because the default encoding used by the open() function cannot handle certain Unicode characters. To fix this issue, the encoding parameter needs to be explicitly specified when opening the file. By using an appropriate encoding, such as 'utf-8', we can ensure that Unicode characters are properly encoded and written to the file. The updated code should be modified as follows: python Copy code print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8')) By making this change, the code will no longer raise the UnicodeEncodeError and will correctly handle Unicode characters during the file write operation. --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index e2e3b6da..10e35ec3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -155,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di img = Image.open(image) filename = os.path.basename(image) left, _ = os.path.splitext(filename) - print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a')) + print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8')) return [gr.update(), None] From f4391796416a998903fcd3e3e0dc7e8cca3614f2 Mon Sep 17 00:00:00 2001 From: gitama2023 <138025603+gitama2023@users.noreply.github.com> Date: Fri, 7 Jul 2023 16:18:01 +0800 Subject: [PATCH 081/178] Added a prompt for users using poor scaling Added a JavaScript file that detects browser scaling and prompts users when scale is not 100% --- javascript/badScaleChecker.js | 108 ++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 javascript/badScaleChecker.js diff --git a/javascript/badScaleChecker.js b/javascript/badScaleChecker.js new file mode 100644 index 00000000..625ad309 --- /dev/null +++ b/javascript/badScaleChecker.js @@ -0,0 +1,108 @@ +(function() { + var ignore = localStorage.getItem("bad-scale-ignore-it") == "ignore-it"; + + function getScale() { + var ratio = 0, + screen = window.screen, + ua = navigator.userAgent.toLowerCase(); + + if (window.devicePixelRatio !== undefined) { + ratio = window.devicePixelRatio; + } else if (~ua.indexOf('msie')) { + if (screen.deviceXDPI && screen.logicalXDPI) { + ratio = screen.deviceXDPI / screen.logicalXDPI; + } + } else if (window.outerWidth !== undefined && window.innerWidth !== undefined) { + ratio = window.outerWidth / window.innerWidth; + } + + return ratio == 0 ? 0 : Math.round(ratio * 100); + } + + var showing = false; + + var div = document.createElement("div"); + div.style.position = "fixed"; + div.style.top = "0px"; + div.style.left = "0px"; + div.style.width = "100vw"; + div.style.backgroundColor = "firebrick"; + div.style.textAlign = "center"; + div.style.zIndex = 99; + + var b = document.createElement("b"); + b.innerHTML = 'Bad Scale: ??% '; + + div.appendChild(b); + + var note1 = document.createElement("p"); + note1.innerHTML = "Change your browser or your computer settings!"; + note1.title = 'Just make sure "computer-scale" * "browser-scale" = 100% ,\n' + + "you can keep your computer-scale and only change this page's scale,\n" + + "for example: your computer-scale is 125%, just use [\"CTRL\"+\"-\"] to make your browser-scale of this page to 80%."; + div.appendChild(note1); + + var note2 = document.createElement("p"); + note2.innerHTML = " Otherwise, it will cause this page to not function properly!"; + note2.title = "When you click \"Copy image to: [inpaint sketch]\" in some img2img's tab,\n" + + "if scale<100% the canvas will be invisible,\n" + + "else if scale>100% this page will take large amount of memory and CPU performance."; + div.appendChild(note2); + + var btn = document.createElement("button"); + btn.innerHTML = "Click here to ignore"; + + div.appendChild(btn); + + function tryShowTopBar(scale) { + if (showing) return; + + b.innerHTML = 'Bad Scale: ' + scale + '% '; + + var updateScaleTimer = setInterval(function() { + var newScale = getScale(); + b.innerHTML = 'Bad Scale: ' + newScale + '% '; + if (newScale == 100) { + var p = div.parentNode; + if (p != null) p.removeChild(div); + showing = false; + clearInterval(updateScaleTimer); + check(); + } + }, 999); + + btn.onclick = function() { + clearInterval(updateScaleTimer); + var p = div.parentNode; + if (p != null) p.removeChild(div); + ignore = true; + showing = false; + localStorage.setItem("bad-scale-ignore-it", "ignore-it"); + }; + + document.body.appendChild(div); + } + + function check() { + if (!ignore) { + var timer = setInterval(function() { + var scale = getScale(); + if (scale != 100 && !ignore) { + tryShowTopBar(scale); + clearInterval(timer); + } + if (ignore) { + clearInterval(timer); + } + }, 999); + } + } + + if (document.readyState != "complete") { + document.onreadystatechange = function() { + if (document.readyState != "complete") check(); + }; + } else { + check(); + } +})(); From a369a0cf658c4371a3b037ded40e22323b6ebce0 Mon Sep 17 00:00:00 2001 From: Nelson Chen Date: Fri, 7 Jul 2023 09:04:49 -0700 Subject: [PATCH 082/178] Add a link to an index-able/crawl-able wiki mirroring service of the wiki At the moment, the wiki is editable by GitHub users, so it is blocked from indexing. If you are searching for something related to this repo, Google and other search engines will not use the content for it. This link hack just sticks a link on the README so search engines may prioritize it. At the moment, 0 pages from GitHub are index and only 7 pages from my proxy service are. If you add this, the rest should get indexed. An indexable page looks like this: https://github-wiki-see.page/m/AUTOMATIC1111/stable-diffusion-webui/wiki/Command-Line-Arguments-and-Settings. It is not meant to be read, just indexed, and users are expected to click through to the GitHub copy. https://github-wiki-see.page/ has more information about the situation. I built the tool and I'm happy to answer any questions I can. Similar: https://github.com/MiSTer-devel/Main_MiSTer#main_mister-main-binary-and-wiki-repo:~:text=For%20the%20purposes%20of%20getting%20google%20to%20crawl%20the%20wiki%2C%20here%27s%20a%20link%20to%20the%20(not%20for%20humans)%20crawlable%20wiki --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 73d94960..e6d8e4bd 100644 --- a/README.md +++ b/README.md @@ -135,8 +135,11 @@ Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-w Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) ## Documentation + The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki). +For the purposes of getting Google and other search engines to crawl the wiki, here's a link to the (not for humans) [crawlable wiki](https://github-wiki-see.page/m/AUTOMATIC1111/stable-diffusion-webui/wiki). + ## Credits Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file. From 19772c3c97647bdda76cd7f652ae517840431e88 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 13:43:42 +0300 Subject: [PATCH 083/178] fix problem with extra network saving images as previews losing generation info add a description for save_image_with_geninfo --- modules/images.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/modules/images.py b/modules/images.py index 1906e2ab..04f55f14 100644 --- a/modules/images.py +++ b/modules/images.py @@ -497,13 +497,23 @@ def get_next_sequence_number(path, basename): return result + 1 -def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None): +def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'): + """ + Saves image to filename, including geninfo as text information for generation info. + For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key. + For JPG images, there's no dictionary and geninfo just replaces the EXIF description. + """ + if extension is None: extension = os.path.splitext(filename)[1] image_format = Image.registered_extensions()[extension] if extension.lower() == '.png': + existing_pnginfo = existing_pnginfo or {} + if opts.enable_pnginfo: + existing_pnginfo[pnginfo_section_name] = geninfo + if opts.enable_pnginfo: pnginfo_data = PngImagePlugin.PngInfo() for k, v in (existing_pnginfo or {}).items(): @@ -622,7 +632,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i """ temp_file_path = f"{filename_without_extension}.tmp" - save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo) + save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name) os.replace(temp_file_path, filename_without_extension + extension) From 7a7fa25d02d469533dab5084bbd08d96d2df45a2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 14:21:40 +0300 Subject: [PATCH 084/178] lint fix for #11492 --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index fdccec09..9a5d9585 100644 --- a/modules/images.py +++ b/modules/images.py @@ -661,7 +661,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i try: # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16 image = image.resize(resize_to, LANCZOS) - except: + except Exception: image = image.resize(resize_to) try: _atomically_save_image(image, fullfn_without_extension, ".jpg") From 3602602260abaa325850e4768b7e253834e207d0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 14:44:02 +0300 Subject: [PATCH 085/178] whitespace for #11477 --- modules/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/scripts.py b/modules/scripts.py index d96f88b0..a07adc42 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -187,7 +187,7 @@ class Script: return f'script_{tabname}{title}_{item_id}' - def before_hr(self, p ,*args): + def before_hr(self, p, *args): """ This function is called before hires fix start. """ From 18256c5f0174126cb103afece2b39b6b831e034a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 14:58:33 +0300 Subject: [PATCH 086/178] fix for #11478 --- webui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/webui.py b/webui.py index b02ae7a5..34c2fd18 100644 --- a/webui.py +++ b/webui.py @@ -43,7 +43,7 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvisi startup_timer.record("import torch") -import gradio +import gradio # noqa: F401 startup_timer.record("import gradio") import ldm.modules.encoders.modules # noqa: F401 @@ -413,7 +413,7 @@ def webui(): "docs_url": "/docs", "redoc_url": "/redoc", }, - root_path = f"/{cmd_opts.subpath}", + root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "", ) # after initial launch, disable --autolaunch for subsequent restarts From b88645d9ebddfa26aaf6ee25519a95c967a23138 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 15:14:14 +0300 Subject: [PATCH 087/178] additional changes for merge conflict for #11337 --- modules/img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/img2img.py b/modules/img2img.py index b8ea3a3c..5e18bab9 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -230,7 +230,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir) + process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_input_dir, png_info_props=img2img_batch_output_dir, png_info_dir=img2img_batch_inpaint_mask_dir) processed = Processed(p, [], p.seed, "") else: From 9043b91649f35adaa732d811184e81afb7a34b71 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 15:14:24 +0300 Subject: [PATCH 088/178] additional changes for merge conflict for #11337 --- modules/ui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index c752a64d..e83f2651 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -934,13 +934,13 @@ def create_ui(): inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, - img2img_batch_use_png_info, - img2img_batch_png_info_props, - img2img_batch_png_info_dir, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings, + img2img_batch_use_png_info, + img2img_batch_png_info_props, + img2img_batch_png_info_dir, ] + custom_inputs, outputs=[ img2img_gallery, From 1d71c36de2d7bbbcd290ba4dc5afd8ba909c74f8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 15:21:29 +0300 Subject: [PATCH 089/178] third time's the charm --- modules/img2img.py | 2 +- modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 5e18bab9..881212fc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -230,7 +230,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_input_dir, png_info_props=img2img_batch_output_dir, png_info_dir=img2img_batch_inpaint_mask_dir) + process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) processed = Processed(p, [], p.seed, "") else: diff --git a/modules/ui.py b/modules/ui.py index e83f2651..39d226ad 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -733,7 +733,7 @@ def create_ui(): img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") - with gr.Accordion("PNG info"): + with gr.Accordion("PNG info", open=False): img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") From 274a3e21babe5fa913b4a34d49b5d7cd72c5fa89 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 15:42:00 +0300 Subject: [PATCH 090/178] small rework for img2img PNG info --- modules/img2img.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 881212fc..a5f1c148 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -96,27 +96,16 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal info_img = Image.open(info_img_path) geninfo, _ = imgutil.read_info_from_image(info_img) parsed_parameters = parse_generation_parameters(geninfo) - if("Prompt" in png_info_props): - p.prompt = prompt + " " + parsed_parameters["Prompt"] - if("Negative prompt" in png_info_props): - p.negative_prompt = negative_prompt + " " + parsed_parameters["Negative prompt"] - if("Seed" in png_info_props): - p.seed = int(parsed_parameters["Seed"]) - if("CFG scale" in png_info_props): - p.cfg_scale = float(parsed_parameters["CFG scale"]) - if("Sampler" in png_info_props): - p.sampler_name = parsed_parameters["Sampler"] - if("Steps" in png_info_props): - p.steps = int(parsed_parameters["Steps"]) - except Exception as e: - print(f"batch png info: using ui set prompts; failed to get png info for {image}") - print(e) - p.prompt = prompt - p.negative_prompt = negative_prompt - p.seed = seed - p.cfg_scale = cfg_scale - p.sampler_name = sampler_name - p.steps = steps + parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})} + except Exception: + parsed_parameters = {} + + p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "") + p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "") + p.seed = int(parsed_parameters.get("Seed", seed)) + p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale)) + p.sampler_name = parsed_parameters.get("Sampler", sampler_name) + p.steps = int(parsed_parameters.get("Steps", steps)) proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: From 7a6abc59ea1ecd8bb311de1719b018fb5960cd80 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 16:15:28 +0300 Subject: [PATCH 091/178] for #10650: change key to alt+arrows, enable by default --- javascript/edit-order.js | 6 +++++- modules/shared.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/javascript/edit-order.js b/javascript/edit-order.js index 50f7fe37..ad983d33 100644 --- a/javascript/edit-order.js +++ b/javascript/edit-order.js @@ -1,8 +1,12 @@ +/* alt+left/right moves text in prompt */ + function keyupEditOrder(event) { if (!opts.keyedit_move) return; + let target = event.originalTarget || event.composedPath()[0]; if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return; - if (!event.metaKey && !event.ctrlKey) return; + if (!event.altKey) return; + event.preventDefault() let isLeft = event.key == "ArrowLeft"; let isRight = event.key == "ArrowRight"; diff --git a/modules/shared.py b/modules/shared.py index b9c53875..b29c3307 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -494,7 +494,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), - "keyedit_move": OptionInfo(False, "Ctrl+left/right moves prompt elements"), + "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), From d7d6e8cfc8b85a99a48f82975ee213d487783c28 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 16:45:59 +0300 Subject: [PATCH 092/178] use natural sort for shared.walk_files and shared.listfiles, as well as for dirs in extra networks --- extensions-builtin/Lora/lora.py | 2 +- modules/shared.py | 14 +++++++++++--- modules/ui_extra_networks.py | 4 ++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 34ff57dd..cd46e6c7 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -443,7 +443,7 @@ def list_available_loras(): os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) - for filename in sorted(candidates, key=str.lower): + for filename in candidates: if os.path.isdir(filename): continue diff --git a/modules/shared.py b/modules/shared.py index b29c3307..48478a68 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,6 +1,7 @@ import datetime import json import os +import re import sys import threading import time @@ -832,8 +833,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts) mem_mon.start() +def natural_sort_key(s, regex=re.compile('([0-9]+)')): + return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)] + + def listfiles(dirname): - filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")] + filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")] return [file for file in filenames if os.path.isfile(file)] @@ -858,8 +863,11 @@ def walk_files(path, allowed_extensions=None): if allowed_extensions is not None: allowed_extensions = set(allowed_extensions) - for root, _, files in os.walk(path, followlinks=True): - for filename in files: + items = list(os.walk(path, followlinks=True)) + items = sorted(items, key=lambda x: natural_sort_key(x[0])) + + for root, _, files in items: + for filename in sorted(files, key=natural_sort_key): if allowed_extensions is not None: _, ext = os.path.splitext(filename) if ext not in allowed_extensions: diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 1efd00b0..693cafb6 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -90,8 +90,8 @@ class ExtraNetworksPage: subdirs = {} for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: - for root, dirs, _ in os.walk(parentdir, followlinks=True): - for dirname in dirs: + for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])): + for dirname in sorted(dirs, key=shared.natural_sort_key): x = os.path.join(root, dirname) if not os.path.isdir(x): From e161b5a0259c870b9d01408d02c504c3281dbdb1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 16:54:03 +0300 Subject: [PATCH 093/178] rework #10436 to use shared.walk_files --- modules/img2img.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 3b83814b..ef87eb0f 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -18,13 +18,7 @@ import modules.scripts def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): processing.fix_seed(p) - images = [] - for root, directories, files in os.walk(input_dir): - for filename in files: - filepath = os.path.join(root, filename) - if filepath.endswith(".jpg") or filepath.endswith(".jpeg") or filepath.endswith(".png") or filepath.endswith(".webp"): - images.append(filepath) - + images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp"))) is_inpaint_batch = False if inpaint_mask_dir: From da8916f92649fc4d947cb46d9d8f8ea1621b2a59 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 17:13:18 +0300 Subject: [PATCH 094/178] added torch.mps.empty_cache() to torch_gc() changed a bunch of places that use torch.cuda.empty_cache() to use torch_gc() instead --- extensions-builtin/LDSR/ldsr_model_arch.py | 8 +++----- extensions-builtin/ScuNET/scripts/scunet_model.py | 4 ++-- extensions-builtin/SwinIR/scripts/swinir_model.py | 5 +---- modules/codeformer_model.py | 2 +- modules/devices.py | 3 +++ modules/sd_models.py | 1 - 6 files changed, 10 insertions(+), 13 deletions(-) diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 7f450086..7cac36ce 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -12,7 +12,7 @@ import safetensors.torch from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config, ismap -from modules import shared, sd_hijack +from modules import shared, sd_hijack, devices cached_ldsr_model: torch.nn.Module = None @@ -112,8 +112,7 @@ class LDSR: gc.collect() - if torch.cuda.is_available: - torch.cuda.empty_cache() + devices.torch_gc() im_og = image width_og, height_og = im_og.size @@ -150,8 +149,7 @@ class LDSR: del model gc.collect() - if torch.cuda.is_available: - torch.cuda.empty_cache() + devices.torch_gc() return a diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index ffef26b2..167d2f64 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -85,7 +85,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def do_upscale(self, img: PIL.Image.Image, selected_file): - torch.cuda.empty_cache() + devices.torch_gc() try: model = self.load_model(selected_file) @@ -110,7 +110,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() del torch_img, torch_output - torch.cuda.empty_cache() + devices.torch_gc() output = np_output.transpose((1, 2, 0)) # CHW to HWC output = output[:, :, ::-1] # BGR to RGB diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index c6bc53a8..c2c2a43c 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -42,10 +42,7 @@ class UpscalerSwinIR(Upscaler): return img model = model.to(device_swinir, dtype=devices.dtype) img = upscale(img, model) - try: - torch.cuda.empty_cache() - except Exception: - pass + devices.torch_gc() return img def load_model(self, path, scale=4): diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index f293acf5..da42b5e9 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -99,7 +99,7 @@ def setup_model(dirname): output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) del output - torch.cuda.empty_cache() + devices.torch_gc() except Exception: errors.report('Failed inference for CodeFormer', exc_info=True) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) diff --git a/modules/devices.py b/modules/devices.py index 620ed1a6..c5ad950f 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -49,10 +49,13 @@ def get_device_for(task): def torch_gc(): + if torch.cuda.is_available(): with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() + elif has_mps() and hasattr(torch.mps, 'empty_cache'): + torch.mps.empty_cache() def enable_tf32(): diff --git a/modules/sd_models.py b/modules/sd_models.py index f65f4e36..653c4cc0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -590,7 +590,6 @@ def unload_model_weights(sd_model=None, info=None): sd_model = None gc.collect() devices.torch_gc() - torch.cuda.empty_cache() print(f"Unloaded weights {timer.summary()}.") From da468a585bb631bc91c3435f349dfb7ce7fe3895 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 3 Jul 2023 12:17:20 +0300 Subject: [PATCH 095/178] Fix typo: checkpoint_alisases --- modules/api/api.py | 4 ++-- modules/processing.py | 2 +- modules/sd_models.py | 11 ++++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 224bbfc6..5793bb44 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_alisases +from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases from modules.sd_vae import vae_dict from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models @@ -519,7 +519,7 @@ class Api: def set_config(self, req: Dict[str, Any]): checkpoint_name = req.get("sd_model_checkpoint", None) - if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases: + if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): diff --git a/modules/processing.py b/modules/processing.py index 21d1492c..cd568a20 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -606,7 +606,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint - if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None: + if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None: p.override_settings.pop('sd_model_checkpoint', None) sd_models.reload_model_weights() diff --git a/modules/sd_models.py b/modules/sd_models.py index 653c4cc0..060e0007 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -23,7 +23,8 @@ model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) checkpoints_list = {} -checkpoint_alisases = {} +checkpoint_aliases = {} +checkpoint_alisases = checkpoint_aliases # for compatibility with old name checkpoints_loaded = collections.OrderedDict() @@ -66,7 +67,7 @@ class CheckpointInfo: def register(self): checkpoints_list[self.title] = self for id in self.ids: - checkpoint_alisases[id] = self + checkpoint_aliases[id] = self def calculate_shorthash(self): self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}") @@ -112,7 +113,7 @@ def checkpoint_tiles(): def list_models(): checkpoints_list.clear() - checkpoint_alisases.clear() + checkpoint_aliases.clear() cmd_ckpt = shared.cmd_opts.ckpt if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt): @@ -136,7 +137,7 @@ def list_models(): def get_closet_checkpoint_match(search_string): - checkpoint_info = checkpoint_alisases.get(search_string, None) + checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: return checkpoint_info @@ -166,7 +167,7 @@ def select_checkpoint(): """Raises `FileNotFoundError` if no checkpoints are found.""" model_checkpoint = shared.opts.sd_model_checkpoint - checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) + checkpoint_info = checkpoint_aliases.get(model_checkpoint, None) if checkpoint_info is not None: return checkpoint_info From 4da92281f65a5d3620e61aef76dc2ec23394e706 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 17:29:28 +0300 Subject: [PATCH 096/178] pin version for torch for Navi3 according to comment from #11228 --- webui.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.sh b/webui.sh index 4b0d7cd8..a683d946 100755 --- a/webui.sh +++ b/webui.sh @@ -134,7 +134,7 @@ case "$gpu_info" in *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 ;; *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ - export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.5" + export TORCH_COMMAND="pip install --pre torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 --index-url https://download.pytorch.org/whl/nightly/rocm5.5" # Navi 3 needs at least 5.5 which is only on the nightly chain ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 From 4981c7d3704e50dd93fe1b68d299239a4ded1ec2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 17:52:03 +0300 Subject: [PATCH 097/178] move github proxy to settings, System page. --- modules/shared.py | 1 + modules/ui_extensions.py | 33 ++++++++++++++------------------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 48478a68..b7518de6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -391,6 +391,7 @@ options_templates.update(options_section(('system', "System"), { "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), + "github_proxy": OptionInfo("None", "Github proxy", ui_components.DropdownEditable, lambda: {"choices": ["None", "ghproxy.com", "hub.yzuu.cf", "hub.njuu.cf", "hub.nuaa.cf"]}).info("for custom inputs will just replace github.com with the input"), })) options_templates.update(options_section(('training', "Training"), { diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index ac523bcf..a208012d 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -325,7 +325,18 @@ def normalize_git_url(url): return url -def install_extension_from_url(dirname, proxy, url, branch_name=None): +def github_proxy(url): + proxy = shared.opts.github_proxy + + if proxy == 'None': + return url + if proxy == 'ghproxy.com': + return "https://ghproxy.com/" + url + + return url.replace('github.com', proxy) + + +def install_extension_from_url(dirname, url, branch_name=None): check_access() if isinstance(dirname, str): @@ -335,18 +346,7 @@ def install_extension_from_url(dirname, proxy, url, branch_name=None): assert url, 'No URL specified' - proxy_list = { - "none": "", - "ghproxy": "https://ghproxy.com/", - "yzuu": "hub.yzuu.cf", - "njuu": "hub.njuu.cf", - "nuaa": "hub.nuaa.cf", - } - - if proxy in ['yzuu', 'njuu', 'nuaa']: - url = url.replace('github.com', proxy_list[proxy]) - elif proxy == 'ghproxy': - url = proxy_list[proxy] + url + url = github_proxy(url) if dirname is None or dirname == "": *parts, last_part = url.split('/') @@ -628,11 +628,6 @@ def create_ui(): ) with gr.TabItem("Install from URL", id="install_from_url"): - install_proxy = gr.Radio( - label="Install Proxy", choices=["none", "ghproxy", "nuaa", "yzuu", "njuu"], value="none", - info="If you can't access github.com, you can use a proxy to install extensions from github.com" - ) - install_url = gr.Text(label="URL for extension's git repository") install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") @@ -641,7 +636,7 @@ def create_ui(): install_button.click( fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), - inputs=[install_dirname, install_proxy, install_url, install_branch], + inputs=[install_dirname, install_url, install_branch], outputs=[install_url, extensions_table, install_result], ) From e3507a1be4826f5c196cb8651d932c9af84a5019 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 17:53:17 +0300 Subject: [PATCH 098/178] fix for eslint --- javascript/edit-order.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/edit-order.js b/javascript/edit-order.js index ad983d33..e6e73937 100644 --- a/javascript/edit-order.js +++ b/javascript/edit-order.js @@ -6,7 +6,7 @@ function keyupEditOrder(event) { let target = event.originalTarget || event.composedPath()[0]; if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return; if (!event.altKey) return; - event.preventDefault() + event.preventDefault(); let isLeft = event.key == "ArrowLeft"; let isRight = event.key == "ArrowRight"; From 44d66daaad3dae283a85329020d1345d08189e32 Mon Sep 17 00:00:00 2001 From: SiYu Wu Date: Sun, 9 Jul 2023 03:05:38 +0800 Subject: [PATCH 099/178] add option SWIN_torch_compile to accelerate SwinIR upscale using torch.compile() --- .../SwinIR/scripts/swinir_model.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index c2c2a43c..ae0d0e6a 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,4 +1,5 @@ import sys +import platform import numpy as np import torch @@ -18,6 +19,8 @@ device_swinir = devices.get_device_for('swinir') class UpscalerSwinIR(Upscaler): def __init__(self, dirname): + self._cached_model = None # keep the model when SWIN_torch_compile is on to prevent re-compile every runs + self._cached_model_config = None # to clear '_cached_model' when changing model (v1/v2) or settings self.name = "SwinIR" self.model_url = SWINIR_MODEL_URL self.model_name = "SwinIR 4x" @@ -35,12 +38,24 @@ class UpscalerSwinIR(Upscaler): self.scalers = scalers def do_upscale(self, img, model_file): - try: - model = self.load_model(model_file) - except Exception as e: - print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) - return img - model = model.to(device_swinir, dtype=devices.dtype) + use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \ + and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows" + current_config = (model_file, opts.SWIN_tile) + + if use_compile and self._cached_model_config == current_config: + model = self._cached_model + else: + self._cached_model = None + try: + model = self.load_model(model_file) + except Exception as e: + print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) + return img + model = model.to(device_swinir, dtype=devices.dtype) + if use_compile: + model = torch.compile(model) + self._cached_model = model + self._cached_model_config = current_config img = upscale(img, model) devices.torch_gc() return img @@ -170,6 +185,8 @@ def on_ui_settings(): shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) + if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows + shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) script_callbacks.on_ui_settings(on_ui_settings) From 75f56406cede0095eb0e5dcc4e0d5759063e89dc Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sun, 9 Jul 2023 22:40:23 +0800 Subject: [PATCH 100/178] Revert Pull Request #11244 Revert "Add github mirror for the download extension" This reverts commit 9ec2ba2d28bb0d8f01e19e2919b7bf2e3e864773. Revert "Update code style" This reverts commit de022c4c80240a430a8099fb27a41aa505bf5b2f. Revert "Update call method" This reverts commit e9bd18c57bd83363d38c7409263fe87f3ed3a7f0. Revert "move github proxy to settings, System page." This reverts commit 4981c7d3704e50dd93fe1b68d299239a4ded1ec2. --- modules/shared.py | 1 - modules/ui_extensions.py | 17 ++--------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index b7518de6..48478a68 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -391,7 +391,6 @@ options_templates.update(options_section(('system', "System"), { "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), - "github_proxy": OptionInfo("None", "Github proxy", ui_components.DropdownEditable, lambda: {"choices": ["None", "ghproxy.com", "hub.yzuu.cf", "hub.njuu.cf", "hub.nuaa.cf"]}).info("for custom inputs will just replace github.com with the input"), })) options_templates.update(options_section(('training', "Training"), { diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index a208012d..dff522ef 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -325,17 +325,6 @@ def normalize_git_url(url): return url -def github_proxy(url): - proxy = shared.opts.github_proxy - - if proxy == 'None': - return url - if proxy == 'ghproxy.com': - return "https://ghproxy.com/" + url - - return url.replace('github.com', proxy) - - def install_extension_from_url(dirname, url, branch_name=None): check_access() @@ -346,8 +335,6 @@ def install_extension_from_url(dirname, url, branch_name=None): assert url, 'No URL specified' - url = github_proxy(url) - if dirname is None or dirname == "": *parts, last_part = url.split('/') last_part = normalize_git_url(last_part) @@ -367,12 +354,12 @@ def install_extension_from_url(dirname, url, branch_name=None): shutil.rmtree(tmpdir, True) if not branch_name: # if no branch is specified, use the default branch - with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], verbose=False) as repo: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() else: - with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name, verbose=False) as repo: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() From 44c27ebc7393ea793245aa565ace6c9bf1313980 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 10 Jul 2023 20:08:23 +0300 Subject: [PATCH 101/178] Use closing() with processing classes everywhere Follows up on #11569 --- modules/hypernetworks/hypernetwork.py | 6 ++++-- modules/img2img.py | 20 +++++++++---------- .../textual_inversion/textual_inversion.py | 6 ++++-- modules/txt2img.py | 11 +++++----- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 51941c11..79670b87 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -3,6 +3,7 @@ import glob import html import os import inspect +from contextlib import closing import modules.textual_inversion.dataset import torch @@ -711,8 +712,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi preview_text = p.prompt - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images) > 0 else None + with closing(p): + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None if unload: shared.sd_model.cond_stage_model.to(devices.cpu) diff --git a/modules/img2img.py b/modules/img2img.py index ef87eb0f..4d9a02cc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -1,4 +1,5 @@ import os +from contextlib import closing from pathlib import Path import numpy as np @@ -217,18 +218,17 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if mask: p.extra_generation_params["Mask blur"] = mask_blur - if is_batch: - assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" + with closing(p): + if is_batch: + assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) + process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) - processed = Processed(p, [], p.seed, "") - else: - processed = modules.scripts.scripts_img2img.run(p, *args) - if processed is None: - processed = process_images(p) - - p.close() + processed = Processed(p, [], p.seed, "") + else: + processed = modules.scripts.scripts_img2img.run(p, *args) + if processed is None: + processed = process_images(p) shared.total_tqdm.clear() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index bb6f211c..cbe975b7 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,5 +1,6 @@ import os from collections import namedtuple +from contextlib import closing import torch import tqdm @@ -584,8 +585,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st preview_text = p.prompt - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images) > 0 else None + with closing(p): + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None if unload: shared.sd_model.first_stage_model.to(devices.cpu) diff --git a/modules/txt2img.py b/modules/txt2img.py index 6aa79f23..d0be2e73 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,3 +1,5 @@ +from contextlib import closing + import modules.scripts from modules import sd_samplers, processing from modules.generation_parameters_copypaste import create_override_settings_dict @@ -53,12 +55,11 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) - processed = modules.scripts.scripts_txt2img.run(p, *args) + with closing(p): + processed = modules.scripts.scripts_txt2img.run(p, *args) - if processed is None: - processed = processing.process_images(p) - - p.close() + if processed is None: + processed = processing.process_images(p) shared.total_tqdm.clear() From 10d4e4ace2d243c020e6a83060c938dee7d8c02d Mon Sep 17 00:00:00 2001 From: TangJicheng Date: Tue, 11 Jul 2023 17:30:57 +0900 Subject: [PATCH 102/178] add cmd_args: --timeout-keep-alive --- modules/cmd_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index de905caa..982d9055 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -107,3 +107,4 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') +parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') From 14501f56aaf3c97fb2c38633350dc747b9651f43 Mon Sep 17 00:00:00 2001 From: TangJicheng Date: Tue, 11 Jul 2023 17:32:04 +0900 Subject: [PATCH 103/178] set timeout_keep_alive --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 7f7e3a9b..4ea5d825 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -715,4 +715,4 @@ class Api: def launch(self, server_name, port): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) + uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive) From b85fc7187d953828340d4e3af34af46d9fc70b9e Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 10 Jul 2023 21:18:34 +0300 Subject: [PATCH 104/178] Fix MPS cache cleanup Importing torch does not import torch.mps so the call failed. --- modules/devices.py | 5 +++-- modules/mac_specific.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index c5ad950f..57e51da3 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -54,8 +54,9 @@ def torch_gc(): with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() - elif has_mps() and hasattr(torch.mps, 'empty_cache'): - torch.mps.empty_cache() + + if has_mps(): + mac_specific.torch_mps_gc() def enable_tf32(): diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 735847f5..2c2f15ca 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,8 +1,12 @@ +import logging + import torch import platform from modules.sd_hijack_utils import CondFunc from packaging import version +log = logging.getLogger() + # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, # use check `getattr` and try it for compatibility. @@ -19,9 +23,19 @@ def check_for_mps() -> bool: return False else: return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + has_mps = check_for_mps() +def torch_mps_gc() -> None: + try: + from torch.mps import empty_cache + empty_cache() + except Exception: + log.warning("MPS garbage collection failed", exc_info=True) + + # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 def cumsum_fix(input, cumsum_func, *args, **kwargs): if input.device.type == 'mps': From 3636c2c6eda6ea25db95a5e3e77fe1ac347f0081 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 11 Jul 2023 15:05:20 +0300 Subject: [PATCH 105/178] Allow using alt in the prompt fields again --- javascript/edit-order.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/edit-order.js b/javascript/edit-order.js index e6e73937..ed4ef9ac 100644 --- a/javascript/edit-order.js +++ b/javascript/edit-order.js @@ -6,11 +6,11 @@ function keyupEditOrder(event) { let target = event.originalTarget || event.composedPath()[0]; if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return; if (!event.altKey) return; - event.preventDefault(); let isLeft = event.key == "ArrowLeft"; let isRight = event.key == "ArrowRight"; if (!isLeft && !isRight) return; + event.preventDefault(); let selectionStart = target.selectionStart; let selectionEnd = target.selectionEnd; From af081211ee93622473ee575de30fed2fd8263c09 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 11 Jul 2023 21:16:43 +0300 Subject: [PATCH 106/178] getting SD2.1 to run on SDXL repo --- modules/launch_utils.py | 3 ++ modules/paths.py | 1 + modules/prompt_parser.py | 64 +++++++++++++++++++++++++------ modules/sd_hijack.py | 9 +++++ modules/sd_hijack_open_clip.py | 4 ++ modules/sd_models.py | 8 +++- modules/sd_models_config.py | 2 + modules/sd_models_xl.py | 40 +++++++++++++++++++ modules/sd_samplers_kdiffusion.py | 45 +++++++++++++++++----- 9 files changed, 152 insertions(+), 24 deletions(-) create mode 100644 modules/sd_models_xl.py diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 0e0dbca4..3b740dbd 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -235,11 +235,13 @@ def prepare_environment(): openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") + stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") + stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -297,6 +299,7 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) + git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) diff --git a/modules/paths.py b/modules/paths.py index bada804e..f509a85f 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,6 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), + (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 0069d8b0..d7f9e9a9 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -144,7 +144,12 @@ def get_learned_conditioning(model, prompts, steps): cond_schedule = [] for i, (end_at_step, _) in enumerate(prompt_schedule): - cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i])) + if isinstance(conds, dict): + cond = {k: v[i] for k, v in conds.items()} + else: + cond = conds[i] + + cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond)) cache[prompt] = cond_schedule res.append(cond_schedule) @@ -214,20 +219,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) +class DictWithShape(dict): + def __init__(self, x, shape): + super().__init__() + self.update(x) + + @property + def shape(self): + return self["crossattn"].shape + + def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): param = c[0][0].cond - res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + is_dict = isinstance(param, dict) + + if is_dict: + dict_cond = param + res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()} + res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape) + else: + res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + for i, cond_schedule in enumerate(c): target_index = 0 for current, entry in enumerate(cond_schedule): if current_step <= entry.end_at_step: target_index = current break - res[i] = cond_schedule[target_index].cond + + if is_dict: + for k, param in cond_schedule[target_index].cond.items(): + res[k][i] = param + else: + res[i] = cond_schedule[target_index].cond return res +def stack_conds(tensors): + # if prompts have wildly different lengths above the limit we'll get tensors of different shapes + # and won't be able to torch.stack them. So this fixes that. + token_count = max([x.shape[0] for x in tensors]) + for i in range(len(tensors)): + if tensors[i].shape[0] != token_count: + last_vector = tensors[i][-1:] + last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) + tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + + return torch.stack(tensors) + + + def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): param = c.batch[0][0].schedules[0].cond @@ -249,16 +291,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch) - # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes - # and won't be able to torch.stack them. So this fixes that. - token_count = max([x.shape[0] for x in tensors]) - for i in range(len(tensors)): - if tensors[i].shape[0] != token_count: - last_vector = tensors[i][-1:] - last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) - tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + if isinstance(tensors[0], dict): + keys = list(tensors[0].keys()) + stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys} + stacked = DictWithShape(stacked, stacked['crossattn'].shape) + else: + stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype) - return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype) + return conds_list, stacked re_attention = re.compile(r""" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3b6f95ce..c4b9211f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -166,6 +166,15 @@ class StableDiffusionModelHijack: undo_optimizations() def hijack(self, m): + conditioner = getattr(m, 'conditioner', None) + if conditioner: + for i in range(len(conditioner.embedders)): + embedder = conditioner.embedders[i] + if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder': + embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) + m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) + conditioner.embedders[i] = m.cond_stage_model + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index f733e852..6ac5bda6 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -16,6 +16,10 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit self.id_end = tokenizer.encoder[""] self.id_pad = 0 + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.legacy_ucg_val = None + def tokenize(self, texts): assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' diff --git a/modules/sd_models.py b/modules/sd_models.py index 060e0007..8d639583 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -289,6 +289,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + if hasattr(model, 'conditioner'): + sd_models_xl.extend_sdxl(model) + model.load_state_dict(state_dict, strict=False) del state_dict timer.record("apply weights to model") @@ -334,7 +337,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.sd_checkpoint_info = checkpoint_info shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 - model.logvar = model.logvar.to(devices.device) # fix for training + if hasattr(model, 'logvar'): + model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 9bfe1237..96501569 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -6,11 +6,13 @@ from modules import shared, paths, sd_disable_initialization sd_configs_path = shared.sd_configs_path sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") +sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference") config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py new file mode 100644 index 00000000..d43b8868 --- /dev/null +++ b/modules/sd_models_xl.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import torch + +import sgm.models.diffusion +import sgm.modules.diffusionmodules.denoiser_scaling +import sgm.modules.diffusionmodules.discretizer +from modules import devices + + +def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]): + for embedder in self.conditioner.embedders: + embedder.ucg_rate = 0.0 + + c = self.conditioner({'txt': batch}) + + return c + + +def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): + return self.model(x, t, cond) + + +def extend_sdxl(model): + dtype = next(model.model.diffusion_model.parameters()).dtype + model.model.diffusion_model.dtype = dtype + model.model.conditioning_key = 'crossattn' + + model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0] + model.cond_stage_key = model.cond_stage_model.input_key + + model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" + + discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model + diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 71581b76..73289ce4 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -53,6 +53,28 @@ k_diffusion_scheduler = { } +def catenate_conds(conds): + if not isinstance(conds[0], dict): + return torch.cat(conds) + + return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} + + +def subscript_cond(cond, a, b): + if not isinstance(cond, dict): + return cond[a:b] + + return {key: vec[a:b] for key, vec in cond.items()} + + +def pad_cond(tensor, repeats, empty): + if not isinstance(tensor, dict): + return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) + + tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) + return tensor + + class CFGDenoiser(torch.nn.Module): """ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) @@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module): if shared.sd_model.model.conditioning_key == "crossattn-adm": image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} else: image_uncond = image_cond - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} + if isinstance(uncond, dict): + make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} + else: + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} if not is_edit_model: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) @@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module): num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] if num_repeats < 0: - tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1) + tensor = pad_cond(tensor, -num_repeats, empty) self.padded_cond_uncond = True elif num_repeats > 0: - uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1) + uncond = pad_cond(uncond, num_repeats, empty) self.padded_cond_uncond = True if tensor.shape[1] == uncond.shape[1] or skip_uncond: if is_edit_model: - cond_in = torch.cat([tensor, uncond, uncond]) + cond_in = catenate_conds([tensor, uncond, uncond]) elif skip_uncond: cond_in = tensor else: - cond_in = torch.cat([tensor, uncond]) + cond_in = catenate_conds([tensor, uncond]) if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in)) + x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b])) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size @@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module): b = min(a + batch_size, tensor.shape[0]) if not is_edit_model: - c_crossattn = [tensor[a:b]] + c_crossattn = subscript_cond(tensor, a, b) else: c_crossattn = torch.cat([tensor[a:b]], uncond) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) if not skip_uncond: - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:])) + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) denoised_image_indexes = [x[0][0] for x in conds_list] if skip_uncond: From 3fee3c34f1b01d21770ab0a226b432cdd8444792 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 12 Jul 2023 02:45:03 -0600 Subject: [PATCH 107/178] Save img2img batch with images.save_image() --- modules/img2img.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 2c497020..15306972 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -8,6 +8,7 @@ from modules import sd_samplers from modules.generation_parameters_copypaste import create_override_settings_dict from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state +from modules.images import save_image import modules.shared as shared import modules.processing as processing from modules.ui import plaintext_to_html @@ -84,17 +85,17 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal proc = process_images(p) for n, processed_image in enumerate(proc.images): - filename = image_path.name + filename = image_path.stem + infotext = proc.infotext(p, n) if n > 0: - left, right = os.path.splitext(filename) - filename = f"{left}-{n}{right}" + filename += f"-{n}" if not save_normally: os.makedirs(output_dir, exist_ok=True) if processed_image.mode == 'RGBA': processed_image = processed_image.convert("RGB") - processed_image.save(os.path.join(output_dir, filename)) + save_image(processed_image, output_dir, None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False) def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): From 6c0d5d1198576dbe664f55cffec27b03d0789efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=80=80=E5=AE=97?= Date: Wed, 12 Jul 2023 16:51:50 +0800 Subject: [PATCH 108/178] fix: check fill size none zero when resize (fixes #11425) --- modules/images.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/modules/images.py b/modules/images.py index 7bbfc3e0..7935b122 100644 --- a/modules/images.py +++ b/modules/images.py @@ -302,12 +302,14 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): if ratio < src_ratio: fill_height = height // 2 - src_h // 2 - res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) - res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) + if fill_height > 0: + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) elif ratio > src_ratio: fill_width = width // 2 - src_w // 2 - res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) - res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) + if fill_width > 0: + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) return res From 8f6b24ce5922174d96eb9776126488cb28694ff8 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 12 Jul 2023 15:16:42 +0300 Subject: [PATCH 109/178] Add correct logger name --- modules/mac_specific.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 2c2f15ca..328b5973 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -5,7 +5,7 @@ import platform from modules.sd_hijack_utils import CondFunc from packaging import version -log = logging.getLogger() +log = logging.getLogger(__name__) # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, From 3d524fd3f1bdb17946bf6fa8a3cdf7b10859c495 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 12 Jul 2023 15:17:13 +0300 Subject: [PATCH 110/178] Don't do MPS GC when there's a latent that could still be sampled --- modules/mac_specific.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 328b5973..9ceb43ba 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -30,6 +30,10 @@ has_mps = check_for_mps() def torch_mps_gc() -> None: try: + from modules.shared import state + if state.current_latent is not None: + log.debug("`current_latent` is set, skipping MPS garbage collection") + return from torch.mps import empty_cache empty_cache() except Exception: From ea49bb06125262e61c44003e42219bc04e38b10b Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 12 Jul 2023 23:30:22 +0900 Subject: [PATCH 111/178] use submit blur for quick settings textbox --- modules/ui_settings.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 0c560b30..a6076bf3 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -260,13 +260,20 @@ class UiSettings: component = self.component_dict[k] info = opts.data_labels[k] - change_handler = component.release if hasattr(component, 'release') else component.change - change_handler( - fn=lambda value, k=k: self.run_settings_single(value, key=k), - inputs=[component], - outputs=[component, self.text_settings], - show_progress=info.refresh is not None, - ) + if isinstance(component, gr.Textbox): + methods = [component.submit, component.blur] + elif hasattr(component, 'release'): + methods = [component.release] + else: + methods = [component.change] + + for method in methods: + method( + fn=lambda value, k=k: self.run_settings_single(value, key=k), + inputs=[component], + outputs=[component, self.text_settings], + show_progress=info.refresh is not None, + ) button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) button_set_checkpoint.click( From da464a3fb39ecc6ea7b22fe87271194480d8501c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 12 Jul 2023 23:52:43 +0300 Subject: [PATCH 112/178] SDXL support --- modules/launch_utils.py | 17 ++++++++++ modules/lowvram.py | 49 ++++++++++++++++++++-------- modules/paths.py | 9 +++++- modules/processing.py | 7 ++-- modules/prompt_parser.py | 23 ++++++++++++-- modules/sd_hijack.py | 23 +++++++++++++- modules/sd_hijack_clip.py | 16 +++++++--- modules/sd_hijack_open_clip.py | 38 +++++++++++++++++++--- modules/sd_hijack_optimizations.py | 51 +++++++++++++++++++++++++----- modules/sd_models.py | 14 ++++++-- modules/sd_models_config.py | 5 ++- modules/sd_models_xl.py | 27 +++++++++++++--- modules/sd_samplers_kdiffusion.py | 2 +- modules/shared.py | 2 ++ requirements.txt | 1 + requirements_versions.txt | 1 + 16 files changed, 241 insertions(+), 44 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 3b740dbd..aa9d1880 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -224,6 +224,20 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension)) +def mute_sdxl_imports(): + """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" + + import importlib + + module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None)) + module.LPIPS = None + sys.modules['taming.modules.losses.lpips'] = module + + module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None)) + module.StableDataModuleFromConfig = None + sys.modules['sgm.data'] = module + + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") @@ -319,11 +333,14 @@ def prepare_environment(): if args.update_all_extensions: git_pull_recursive(extensions_dir) + mute_sdxl_imports() + if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) + def configure_for_tests(): if "--api" not in sys.argv: sys.argv.append("--api") diff --git a/modules/lowvram.py b/modules/lowvram.py index d95bcfbf..da4f33a8 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram): send_me_to_gpu(first_stage_model, None) return first_stage_model_decode(z) - # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field - if hasattr(sd_model.cond_stage_model, 'model'): - sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model + to_remain_in_cpu = [ + (sd_model, 'first_stage_model'), + (sd_model, 'depth_model'), + (sd_model, 'embedder'), + (sd_model, 'model'), + (sd_model, 'embedder'), + ] - # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then - # send the model to GPU. Then put modules back. the modules will be in CPU. - stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model - sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None + is_sdxl = hasattr(sd_model, 'conditioner') + is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model') + + if is_sdxl: + to_remain_in_cpu.append((sd_model, 'conditioner')) + elif is_sd2: + to_remain_in_cpu.append((sd_model.cond_stage_model, 'model')) + else: + to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer')) + + # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model + stored = [] + for obj, field in to_remain_in_cpu: + module = getattr(obj, field, None) + stored.append(module) + setattr(obj, field, None) + + # send the model to GPU. sd_model.to(devices.device) - sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored + + # put modules back. the modules will be in CPU. + for (obj, field), module in zip(to_remain_in_cpu, stored): + setattr(obj, field, module) # register hooks for those the first three models - sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) + if is_sdxl: + sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu) + elif is_sd2: + sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu) + else: + sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) + sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.encode = first_stage_model_encode_wrap sd_model.first_stage_model.decode = first_stage_model_decode_wrap @@ -75,10 +102,6 @@ def setup_for_low_vram(sd_model, use_medvram): sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model - if hasattr(sd_model.cond_stage_model, 'model'): - sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer - del sd_model.cond_stage_model.transformer - if use_medvram: sd_model.model.register_forward_pre_hook(send_me_to_gpu) else: diff --git a/modules/paths.py b/modules/paths.py index f509a85f..1100a8dc 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,7 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), - (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []), + (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), @@ -36,6 +36,13 @@ for d, must_exist, what, options in path_dirs: d = os.path.abspath(d) if "atstart" in options: sys.path.insert(0, d) + elif "sgm" in options: + # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we + # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir. + + sys.path.insert(0, d) + import sgm + sys.path.pop(0) else: sys.path.append(d) paths[what] = d diff --git a/modules/processing.py b/modules/processing.py index cd568a20..85d35423 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -343,10 +343,13 @@ class StableDiffusionProcessing: return cache[1] def setup_conds(self): + prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) + negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height) + sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index d7f9e9a9..33810669 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from collections import namedtuple from typing import List @@ -109,7 +111,19 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) -def get_learned_conditioning(model, prompts, steps): +class SdConditioning(list): + """ + A list with prompts for stable diffusion's conditioner model. + Can also specify width and height of created image - SDXL needs it. + """ + def __init__(self, prompts, width=None, height=None): + super().__init__() + self.extend(prompts) + self.width = width or getattr(prompts, 'width', None) + self.height = height or getattr(prompts, 'height', None) + + +def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), and the sampling step at which this condition is to be replaced by the next one. @@ -160,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps): re_AND = re.compile(r"\bAND\b") re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$") -def get_multicond_prompt_list(prompts): + +def get_multicond_prompt_list(prompts: SdConditioning | list[str]): res_indexes = [] - prompt_flat_list = [] prompt_indexes = {} + prompt_flat_list = SdConditioning(prompts) + prompt_flat_list.clear() for prompt in prompts: subprompts = re_AND.split(prompt) @@ -201,6 +217,7 @@ class MulticondLearnedConditioning: self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS self.batch: List[List[ComposableScheduledPromptConditioning]] = batch + def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. For each prompt, the list is obtained by splitting the prompt using the AND separator. diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c4b9211f..266811f9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim import ldm.models.diffusion.plms import ldm.modules.encoders.modules +import sgm.modules.attention +import sgm.modules.diffusionmodules.model +import sgm.modules.diffusionmodules.openaimodel +import sgm.modules.encoders.modules + attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward @@ -56,6 +61,9 @@ def apply_optimizations(option=None): ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th + sgm.modules.diffusionmodules.model.nonlinearity = silu + sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th + if current_optimizer is not None: current_optimizer.undo() current_optimizer = None @@ -89,6 +97,10 @@ def undo_optimizations(): ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + def fix_checkpoint(): """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want @@ -170,10 +182,19 @@ class StableDiffusionModelHijack: if conditioner: for i in range(len(conditioner.embedders)): embedder = conditioner.embedders[i] - if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder': + typename = type(embedder).__name__ + if typename == 'FrozenOpenCLIPEmbedder': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) conditioner.embedders[i] = m.cond_stage_model + if typename == 'FrozenCLIPEmbedder': + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self) + conditioner.embedders[i] = m.cond_stage_model + if typename == 'FrozenOpenCLIPEmbedder2': + embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) + conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 3b5a7666..6c17a81d 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): self.hijack: sd_hijack.StableDiffusionModelHijack = hijack self.chunk_length = 75 + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.legacy_ucg_val = None + def empty_chunk(self): """creates an empty PromptChunk and returns it""" @@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): """ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will - be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024. + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. An example shape returned by this function can be: (2, 77, 768). + For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" """ @@ -233,7 +238,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) self.hijack.comments.append(f"Used embeddings: {embeddings_list}") - return torch.hstack(zs) + if getattr(self.wrapped, 'return_pooled', False): + return torch.hstack(zs), zs[0].pooled + else: + return torch.hstack(zs) def process_tokens(self, remade_batch_tokens, batch_multipliers): """ @@ -256,9 +264,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() - z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() - z = z * (original_mean / new_mean) + z *= (original_mean / new_mean) return z diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index 6ac5bda6..fcf5ad07 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -16,10 +16,6 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit self.id_end = tokenizer.encoder[""] self.id_pad = 0 - self.is_trainable = getattr(wrapped, 'is_trainable', False) - self.input_key = getattr(wrapped, 'input_key', 'txt') - self.legacy_ucg_val = None - def tokenize(self, texts): assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' @@ -39,3 +35,37 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) return embedded + + +class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ','][0] + self.id_start = tokenizer.encoder[""] + self.id_end = tokenizer.encoder[""] + self.id_pad = 0 + + def tokenize(self, texts): + assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' + + tokenized = [tokenizer.encode(text) for text in texts] + + return tokenized + + def encode_with_transformers(self, tokens): + d = self.wrapped.encode_with_transformer(tokens) + z = d[self.wrapped.layer] + + pooled = d.get("pooled") + if pooled is not None: + z.pooled = pooled + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + ids = tokenizer.encode(init_text) + ids = torch.asarray([ids], device=devices.device, dtype=torch.int) + embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + + return embedded diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 53e27ade..e99c9ba5 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork import ldm.modules.attention import ldm.modules.diffusionmodules.model +import sgm.modules.attention +import sgm.modules.diffusionmodules.model + diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward class SdOptimization: @@ -39,6 +43,9 @@ class SdOptimization: ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward + class SdOptimizationXformers(SdOptimization): name = "xformers" @@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward + sgm.modules.attention.CrossAttention.forward = xformers_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward class SdOptimizationSdpNoMem(SdOptimization): @@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward + sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward class SdOptimizationSdp(SdOptimizationSdpNoMem): @@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem): def apply(self): ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward + sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward class SdOptimizationSubQuad(SdOptimization): @@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward + sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward class SdOptimizationV1(SdOptimization): @@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization): cmd_opt = "opt_split_attention_v1" priority = 10 - def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 class SdOptimizationInvokeAI(SdOptimization): @@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI class SdOptimizationDoggettx(SdOptimization): @@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward def list_optimizers(res): @@ -155,7 +173,7 @@ def get_available_vram(): # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None): +def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) @@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): # taken from https://github.com/Doggettx/stable-diffusion and modified -def split_cross_attention_forward(self, x, context=None, mask=None): +def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) @@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None): # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- mem_total_gb = psutil.virtual_memory().total // (1 << 30) + def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) s = s.softmax(dim=-1, dtype=s.dtype) return einsum('b i j, b j d -> b i d', s, v) + def einsum_op_slice_0(q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[0], slice_size): @@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size): r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) return r + def einsum_op_slice_1(q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): @@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size): r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) return r + def einsum_op_mps_v1(q, k, v): if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 return einsum_op_compvis(q, k, v) @@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v): slice_size -= 1 return einsum_op_slice_1(q, k, v, slice_size) + def einsum_op_mps_v2(q, k, v): if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: return einsum_op_compvis(q, k, v) else: return einsum_op_slice_0(q, k, v, 1) + def einsum_op_tensor_mem(q, k, v, max_tensor_mb): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: @@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb): return einsum_op_slice_0(q, k, v, q.shape[0] // div) return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + def einsum_op_cuda(q, k, v): stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] @@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v): # Divide factor of safety as there's copying and fragmentation return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + def einsum_op(q, k, v): if q.device.type == 'cuda': return einsum_op_cuda(q, k, v) @@ -328,7 +354,8 @@ def einsum_op(q, k, v): # Tested on i7 with 8MB L3 cache. return einsum_op_tensor_mem(q, k, v, 32) -def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): + +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q = self.to_q(x) @@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -def sub_quad_attention_forward(self, x, context=None, mask=None): +def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." h = self.heads @@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): return x + def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): bytes_per_token = torch.finfo(q.dtype).bits//8 batch_x_heads, q_tokens, _ = q.shape @@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v): return None -def xformers_attention_forward(self, x, context=None, mask=None): +def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) context = default(context, x) @@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None): out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) + # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -def scaled_dot_product_attention_forward(self, x, context=None, mask=None): +def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): batch_size, sequence_length, inner_dim = x.shape if mask is not None: @@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None): hidden_states = self.to_out[1](hidden_states) return hidden_states -def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None): + +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return scaled_dot_product_attention_forward(self, x, context, mask) + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x): return h3 + def xformers_attnblock_forward(self, x): try: h_ = x @@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x): except NotImplementedError: return cross_attention_attnblock_forward(self, x) + def sdp_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x): out = self.proj_out(out) return x + out + def sdp_no_mem_attnblock_forward(self, x): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return sdp_attnblock_forward(self, x) + def sub_quad_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) diff --git a/modules/sd_models.py b/modules/sd_models.py index 8d639583..e4aae597 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -411,6 +411,7 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' +sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' class SdModelData: @@ -445,6 +446,15 @@ class SdModelData: model_data = SdModelData() +def get_empty_cond(sd_model): + if hasattr(sd_model, 'conditioner'): + d = sd_model.get_learned_conditioning([""]) + return d['crossattn'] + else: + return sd_model.cond_stage_model([""]) + + + def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -465,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict + clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict timer.record("find config") @@ -517,7 +527,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("scripts callbacks") with devices.autocast(), torch.no_grad(): - sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""]) + sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model) timer.record("calculate empty prompt") diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 96501569..2e92479a 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") +config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -70,7 +71,9 @@ def guess_model_config_from_state_dict(sd, filename): diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) - if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: + if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: + return config_sdxl + elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: return config_depth_model elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: return config_unclip diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index d43b8868..e8e270c3 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -1,18 +1,30 @@ from __future__ import annotations +import sys + import torch import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer -from modules import devices +from modules import devices, shared, prompt_parser -def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]): +def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): for embedder in self.conditioner.embedders: embedder.ucg_rate = 0.0 - c = self.conditioner({'txt': batch}) + width = getattr(self, 'target_width', 1024) + height = getattr(self, 'target_height', 1024) + + sdxl_conds = { + "txt": batch, + "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + } + + c = self.conditioner(sdxl_conds) return c @@ -26,7 +38,7 @@ def extend_sdxl(model): model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' - model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0] + model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0] model.cond_stage_key = model.cond_stage_model.input_key model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" @@ -34,7 +46,14 @@ def extend_sdxl(model): discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.is_xl = True + sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.modules.attention.print = lambda *args: None +sgm.modules.diffusionmodules.model.print = lambda *args: None +sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None +sgm.modules.encoders.modules.print = lambda *args: None + diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 73289ce4..5552a8dc 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -186,7 +186,7 @@ class CFGDenoiser(torch.nn.Module): for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b])) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size diff --git a/modules/shared.py b/modules/shared.py index b7518de6..71afd94f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,6 +428,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), + "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"), + "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"), })) options_templates.update(options_section(('optimizations', "Optimizations"), { diff --git a/requirements.txt b/requirements.txt index 3142085e..b3f8a7f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ kornia lark numpy omegaconf +open-clip-torch piexif psutil diff --git a/requirements_versions.txt b/requirements_versions.txt index f71b9d6c..b826bf43 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -15,6 +15,7 @@ kornia==0.6.7 lark==1.1.2 numpy==1.23.5 omegaconf==2.2.3 +open-clip-torch==2.20.0 piexif==1.1.3 psutil~=5.9.5 pytorch_lightning==1.9.4 From 5cf623c58ef3c158e8b25f7c3d516ffc16769fa4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 00:08:19 +0300 Subject: [PATCH 113/178] linter --- modules/paths.py | 2 +- modules/sd_models_xl.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/modules/paths.py b/modules/paths.py index 1100a8dc..c6f8904e 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -41,7 +41,7 @@ for d, must_exist, what, options in path_dirs: # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir. sys.path.insert(0, d) - import sgm + import sgm # noqa: F401 sys.path.pop(0) else: sys.path.append(d) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index e8e270c3..9224c1a3 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -1,7 +1,5 @@ from __future__ import annotations -import sys - import torch import sgm.models.diffusion From a04c95512148fc6df64535a995fbc8f499cae206 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 00:12:25 +0300 Subject: [PATCH 114/178] fix importlib.machinery issue on github's autotests #yolo --- modules/launch_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index aa9d1880..4f48f3a1 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -227,13 +227,14 @@ def run_extensions_installers(settings_file): def mute_sdxl_imports(): """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" - import importlib + class Dummy: + pass - module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None)) + module = Dummy() module.LPIPS = None sys.modules['taming.modules.losses.lpips'] = module - module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None)) + module = Dummy() module.StableDataModuleFromConfig = None sys.modules['sgm.data'] = module From b717eb7e56a4e620e77a2225e80223c89cb4f0d1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 08:29:37 +0300 Subject: [PATCH 115/178] mute unneeded SDXL imports for tests too --- modules/launch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 4f48f3a1..56b972d5 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -334,8 +334,6 @@ def prepare_environment(): if args.update_all_extensions: git_pull_recursive(extensions_dir) - mute_sdxl_imports() - if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) @@ -357,6 +355,8 @@ def configure_for_tests(): def start(): + mute_sdxl_imports() + print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") import webui if '--nowebui' in sys.argv: From ac4ccfa1369e74492b467294eab96c3f558b297b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 09:30:33 +0300 Subject: [PATCH 116/178] get attention optimizations to work --- modules/hypernetworks/hypernetwork.py | 2 +- modules/launch_utils.py | 1 + modules/sd_hijack_optimizations.py | 14 +++++++------- modules/sd_models_xl.py | 3 +++ 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 79670b87..c4821d21 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None): return context_k, context_v -def attention_CrossAttention_forward(self, x, context=None, mask=None): +def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q = self.to_q(x) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 56b972d5..183730d2 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -239,6 +239,7 @@ def mute_sdxl_imports(): sys.modules['sgm.data'] = module + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index e99c9ba5..b5f85ba5 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -173,7 +173,7 @@ def get_available_vram(): # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) @@ -214,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None, additiona # taken from https://github.com/Doggettx/stable-diffusion and modified -def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) @@ -355,7 +355,7 @@ def einsum_op(q, k, v): return einsum_op_tensor_mem(q, k, v, 32) -def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs): h = self.heads q = self.to_q(x) @@ -383,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, add # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." h = self.heads @@ -470,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v): return None -def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) context = default(context, x) @@ -496,7 +496,7 @@ def xformers_attention_forward(self, x, context=None, mask=None, additional_toke # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): batch_size, sequence_length, inner_dim = x.shape if mask is not None: @@ -537,7 +537,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None, addit return hidden_states -def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return scaled_dot_product_attention_forward(self, x, context, mask) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 9224c1a3..4d1aa497 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -55,3 +55,6 @@ sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None sgm.modules.encoders.modules.print = lambda *args: None +# this gets the code to load the vanilla attention that we override +sgm.modules.attention.SDP_IS_AVAILABLE = True +sgm.modules.attention.XFORMERS_IS_AVAILABLE = False \ No newline at end of file From 21aec6f567f52271efbbe33a2ab6561f9a47b787 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 09:38:54 +0300 Subject: [PATCH 117/178] lint --- modules/sd_models_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 4d1aa497..1dd4459f 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -57,4 +57,4 @@ sgm.modules.encoders.modules.print = lambda *args: None # this gets the code to load the vanilla attention that we override sgm.modules.attention.SDP_IS_AVAILABLE = True -sgm.modules.attention.XFORMERS_IS_AVAILABLE = False \ No newline at end of file +sgm.modules.attention.XFORMERS_IS_AVAILABLE = False From 594c8e7b263d9b37f4b18b56b159aeb6d1bba1b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 11:35:52 +0300 Subject: [PATCH 118/178] fix CLIP doing the unneeded normalization revert SD2.1 back to use the original repo add SDXL's force_zero_embeddings to negative prompt --- modules/processing.py | 2 +- modules/prompt_parser.py | 14 ++++++++++---- modules/sd_hijack.py | 2 +- modules/sd_hijack_clip.py | 15 +++++++++++++++ modules/sd_models_config.py | 1 - modules/sd_models_xl.py | 3 ++- 6 files changed, 29 insertions(+), 8 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 85d35423..f01a6907 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -344,7 +344,7 @@ class StableDiffusionProcessing: def setup_conds(self): prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) - negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height) + negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True) sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 33810669..b29d079d 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -116,11 +116,17 @@ class SdConditioning(list): A list with prompts for stable diffusion's conditioner model. Can also specify width and height of created image - SDXL needs it. """ - def __init__(self, prompts, width=None, height=None): + def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None): super().__init__() self.extend(prompts) - self.width = width or getattr(prompts, 'width', None) - self.height = height or getattr(prompts, 'height', None) + + if copy_from is None: + copy_from = prompts + + self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False) + self.width = width or getattr(copy_from, 'width', None) + self.height = height or getattr(copy_from, 'height', None) + def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): @@ -153,7 +159,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): res.append(cached) continue - texts = [x[1] for x in prompt_schedule] + texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts) conds = model.get_learned_conditioning(texts) cond_schedule = [] diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 266811f9..647cdfbe 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -190,7 +190,7 @@ class StableDiffusionModelHijack: if typename == 'FrozenCLIPEmbedder': model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) conditioner.embedders[i] = m.cond_stage_model if typename == 'FrozenOpenCLIPEmbedder2': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 6c17a81d..b3771909 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -323,3 +323,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0) return embedded + + +class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden") + + if self.wrapped.layer == "last": + z = outputs.last_hidden_state + else: + z = outputs.hidden_states[self.wrapped.layer_idx] + + return z diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 2e92479a..04c09ab0 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -12,7 +12,6 @@ sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "conf config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") -config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1dd4459f..b799ff46 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -22,7 +22,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), } - c = self.conditioner(sdxl_conds) + force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch) + c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) return c From 76ebb175ca996e93c063e7109c9f478a268952b6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 12:59:31 +0300 Subject: [PATCH 119/178] lora support --- extensions-builtin/Lora/lora.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index cd46e6c7..03f1ef85 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -179,6 +179,11 @@ def load_lora(name, lora_on_disk): if m: sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) + # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" + if sd_module is None and "lora_unet" in key_diffusers_without_lora_parts: + key = key_diffusers_without_lora_parts.replace("lora_unet", "diffusion_model") + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: keys_failed_to_match[key_diffusers] = key continue From 6f23da603d3cbba82262a3c62cc44c8d5cb9e6db Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 16:18:39 +0300 Subject: [PATCH 120/178] fix broken img2img --- modules/sd_models_xl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index b799ff46..b19036f1 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -32,6 +32,9 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): return self.model(x, t, cond) +def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility + return x + def extend_sdxl(model): dtype = next(model.model.diffusion_model.parameters()).dtype model.model.diffusion_model.dtype = dtype @@ -50,6 +53,7 @@ def extend_sdxl(model): sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding sgm.modules.attention.print = lambda *args: None sgm.modules.diffusionmodules.model.print = lambda *args: None From b8159d0919dcaa3a1a8f29e3aa30c25fe8e5f13b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:24:54 +0300 Subject: [PATCH 121/178] add XL support for live previews: approx and TAESD --- modules/sd_models_xl.py | 2 +- modules/sd_vae_approx.py | 37 ++++++++++++++++++++++++++----------- modules/sd_vae_taesd.py | 26 +++++++++++++------------- 3 files changed, 40 insertions(+), 25 deletions(-) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index b19036f1..af445a61 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -48,7 +48,7 @@ def extend_sdxl(model): discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) - model.is_xl = True + model.is_sdxl = True sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index e2f00468..b348f3ae 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -2,9 +2,9 @@ import os import torch from torch import nn -from modules import devices, paths +from modules import devices, paths, shared -sd_vae_approx_model = None +sd_vae_approx_models = {} class VAEApprox(nn.Module): @@ -31,19 +31,34 @@ class VAEApprox(nn.Module): return x +def download_model(model_path, model_url): + if not os.path.exists(model_path): + os.makedirs(os.path.dirname(model_path), exist_ok=True) + + print(f'Downloading VAEApprox model to: {model_path}') + torch.hub.download_url_to_file(model_url, model_path) + + def model(): - global sd_vae_approx_model + model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt" + loaded_model = sd_vae_approx_models.get(model_name) - if sd_vae_approx_model is None: - model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt") - sd_vae_approx_model = VAEApprox() + if loaded_model is None: + model_path = os.path.join(paths.models_path, "VAE-approx", model_name) if not os.path.exists(model_path): - model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt") - sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) - sd_vae_approx_model.eval() - sd_vae_approx_model.to(devices.device, devices.dtype) + model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name) - return sd_vae_approx_model + if not os.path.exists(model_path): + model_path = os.path.join(paths.models_path, "VAE-approx", model_name) + download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name) + + loaded_model = VAEApprox() + loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_approx_models[model_name] = loaded_model + + return loaded_model def cheap_approximation(sample): diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5e8496e8..5bf7c76e 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -8,9 +8,9 @@ import os import torch import torch.nn as nn -from modules import devices, paths_internal +from modules import devices, paths_internal, shared -sd_vae_taesd = None +sd_vae_taesd_models = {} def conv(n_in, n_out, **kwargs): @@ -61,9 +61,7 @@ class TAESD(nn.Module): return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) -def download_model(model_path): - model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' - +def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) @@ -72,17 +70,19 @@ def download_model(model_path): def model(): - global sd_vae_taesd + model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) - if sd_vae_taesd is None: - model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth") - download_model(model_path) + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): - sd_vae_taesd = TAESD(model_path) - sd_vae_taesd.eval() - sd_vae_taesd.to(devices.device, devices.dtype) + loaded_model = TAESD(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model else: raise FileNotFoundError('TAESD model not found') - return sd_vae_taesd.decoder + return loaded_model.decoder From e16ebc917dfc902f041963df0d4e99e8141cf82f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:32:35 +0300 Subject: [PATCH 122/178] repair --no-half for SDXL --- modules/sd_models.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index e4aae597..9e8cb3cf 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -395,10 +395,11 @@ def repair_config(sd_config): if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True + if hasattr(sd_config.model.params, 'unet_config'): + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" From ff73841c608f5f02e6352bb235d9dbf63d922990 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:42:16 +0300 Subject: [PATCH 123/178] mute SDXL imports in the place there SDXL is imported for the first time instead of launch.py --- modules/launch_utils.py | 18 ------------------ modules/paths.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 183730d2..01ea7c91 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -224,22 +224,6 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension)) -def mute_sdxl_imports(): - """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" - - class Dummy: - pass - - module = Dummy() - module.LPIPS = None - sys.modules['taming.modules.losses.lpips'] = module - - module = Dummy() - module.StableDataModuleFromConfig = None - sys.modules['sgm.data'] = module - - - def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") @@ -356,8 +340,6 @@ def configure_for_tests(): def start(): - mute_sdxl_imports() - print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") import webui if '--nowebui' in sys.argv: diff --git a/modules/paths.py b/modules/paths.py index c6f8904e..25052339 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -5,6 +5,21 @@ from modules.paths_internal import models_path, script_path, data_path, extensio import modules.safe # noqa: F401 +def mute_sdxl_imports(): + """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" + + class Dummy: + pass + + module = Dummy() + module.LPIPS = None + sys.modules['taming.modules.losses.lpips'] = module + + module = Dummy() + module.StableDataModuleFromConfig = None + sys.modules['sgm.data'] = module + + # data_path = cmd_opts_pre.data sys.path.insert(0, script_path) @@ -18,6 +33,8 @@ for possible_sd_path in possible_sd_paths: assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}" +mute_sdxl_imports() + path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), From 6c5f83b19b331d51bde28c5033d13d0d64c11e54 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 21:17:50 +0300 Subject: [PATCH 124/178] add support for SDXL loras with te1/te2 modules --- extensions-builtin/Lora/lora.py | 41 +++++++++++++++++++++++++-------- modules/sd_models.py | 3 ++- modules/sd_models_xl.py | 1 - 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 03f1ef85..4b5da7b5 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -68,6 +68,14 @@ def convert_diffusers_name_to_compvis(key, is_sd2): return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): + if 'mlp_fc1' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + return key @@ -142,10 +150,20 @@ class LoraUpDownModule: def assign_lora_names_to_compvis_modules(sd_model): lora_layer_mapping = {} - for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): - lora_name = name.replace(".", "_") - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name + if shared.sd_model.is_sdxl: + for i, embedder in enumerate(shared.sd_model.conditioner.embedders): + if not hasattr(embedder, 'wrapped'): + continue + + for name, module in embedder.wrapped.named_modules(): + lora_name = f'{i}_{name.replace(".", "_")}' + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + else: + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name for name, module in shared.sd_model.model.named_modules(): lora_name = name.replace(".", "_") @@ -168,10 +186,10 @@ def load_lora(name, lora_on_disk): keys_failed_to_match = {} is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping - for key_diffusers, weight in sd.items(): - key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1) - key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2) + for key_lora, weight in sd.items(): + key_lora_without_lora_parts, lora_key = key_lora.split(".", 1) + key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2) sd_module = shared.sd_model.lora_layer_mapping.get(key, None) if sd_module is None: @@ -180,12 +198,15 @@ def load_lora(name, lora_on_disk): sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" - if sd_module is None and "lora_unet" in key_diffusers_without_lora_parts: - key = key_diffusers_without_lora_parts.replace("lora_unet", "diffusion_model") + if sd_module is None and "lora_unet" in key_lora_without_lora_parts: + key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model") + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts: + key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model") sd_module = shared.sd_model.lora_layer_mapping.get(key, None) if sd_module is None: - keys_failed_to_match[key_diffusers] = key + keys_failed_to_match[key_lora] = key continue lora_module = lora.modules.get(key, None) diff --git a/modules/sd_models.py b/modules/sd_models.py index 9e8cb3cf..07702175 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -289,7 +289,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if hasattr(model, 'conditioner'): + model.is_sdxl = hasattr(model, 'conditioner') + if model.is_sdxl: sd_models_xl.extend_sdxl(model) model.load_state_dict(state_dict, strict=False) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index af445a61..a7240dc0 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -48,7 +48,6 @@ def extend_sdxl(model): discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) - model.is_sdxl = True sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning From dc3906185656dae75fcefe96625b1dcd0d31579c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 21:19:41 +0300 Subject: [PATCH 125/178] thank you linter --- extensions-builtin/Lora/lora.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 4b5da7b5..302490fb 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -229,9 +229,9 @@ def load_lora(name, lora_on_disk): elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) else: - print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') + print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}') continue - raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}") + raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}") with torch.no_grad(): module.weight.copy_(weight) @@ -243,7 +243,7 @@ def load_lora(name, lora_on_disk): elif lora_key == "lora_down.weight": lora_module.down = module else: - raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha") + raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha") if keys_failed_to_match: print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") From a3db187e4f1ef59a571b1023309923f5e5e6dda3 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 14 Jul 2023 05:48:14 +0900 Subject: [PATCH 126/178] handles model hash cache.json error --- modules/hashes.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/modules/hashes.py b/modules/hashes.py index 8b7ea0ac..ec1187fe 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -5,7 +5,7 @@ import os.path import filelock from modules import shared -from modules.paths import data_path +from modules.paths import data_path, script_path cache_filename = os.path.join(data_path, "cache.json") @@ -26,8 +26,13 @@ def cache(subsection): if not os.path.isfile(cache_filename): cache_data = {} else: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) + try: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') + cache_data = {} s = cache_data.get(subsection, {}) cache_data[subsection] = s From 6d8dcdefa07d5f8f7e528046b0facdcc51185e60 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:16:01 +0300 Subject: [PATCH 127/178] initial SDXL refiner support --- modules/sd_hijack.py | 18 ++++++++---- modules/sd_models.py | 3 +- modules/sd_models_config.py | 3 ++ modules/sd_models_xl.py | 57 ++++++++++++++++++++++++++++++------- modules/shared.py | 9 ++++-- 5 files changed, 71 insertions(+), 19 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 647cdfbe..2b274c18 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -180,21 +180,29 @@ class StableDiffusionModelHijack: def hijack(self, m): conditioner = getattr(m, 'conditioner', None) if conditioner: + text_cond_models = [] + for i in range(len(conditioner.embedders)): embedder = conditioner.embedders[i] typename = type(embedder).__name__ if typename == 'FrozenOpenCLIPEmbedder': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) - m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenCLIPEmbedder': - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings = embedder.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenOpenCLIPEmbedder2': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) + + if len(text_cond_models) == 1: + m.cond_stage_model = text_cond_models[0] + else: + m.cond_stage_model = conditioner if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings diff --git a/modules/sd_models.py b/modules/sd_models.py index 07702175..267f4d8e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -414,6 +414,7 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' +sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' class SdModelData: @@ -477,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict + clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict]) timer.record("find config") diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 04c09ab0..8266fa39 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") +config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -72,6 +73,8 @@ def guess_model_config_from_state_dict(sd, filename): if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: return config_sdxl + if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: + return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: return config_depth_model elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index a7240dc0..01320c7a 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -14,15 +14,20 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: width = getattr(self, 'target_width', 1024) height = getattr(self, 'target_height', 1024) + is_negative_prompt = getattr(batch, 'is_negative_prompt', False) + aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score + + devices_args = dict(device=devices.device, dtype=devices.dtype) sdxl_conds = { "txt": batch, - "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), - "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype), - "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), + "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), } - force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch) + force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) return c @@ -35,25 +40,55 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility return x + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding + + +def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): + res = [] + + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: + encoded = embedder.encode_embedding_init_text(init_text, nvpt) + res.append(encoded) + + return torch.cat(res, dim=1) + + +def process_texts(self, texts): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: + return embedder.process_texts(texts) + + +def get_target_prompt_token_count(self, token_count): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: + return embedder.get_target_prompt_token_count(token_count) + + +# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist +sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text +sgm.modules.GeneralConditioner.process_texts = process_texts +sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count + + def extend_sdxl(model): + """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" + dtype = next(model.model.diffusion_model.parameters()).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' - - model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0] - model.cond_stage_key = model.cond_stage_model.input_key + model.cond_stage_key = 'txt' + # model.cond_stage_model will be set in sd_hijack model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.conditioner.wrapped = torch.nn.Module() -sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning -sgm.models.diffusion.DiffusionEngine.apply_model = apply_model -sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding - sgm.modules.attention.print = lambda *args: None sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None diff --git a/modules/shared.py b/modules/shared.py index 71afd94f..234ede0d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,8 +428,13 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), - "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"), - "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"), +})) + +options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { + "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), + "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), + "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), + "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) options_templates.update(options_section(('optimizations', "Optimizations"), { From b7dbeda0d9e475aafa9db0cfe015bf724502ec20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:19:08 +0300 Subject: [PATCH 128/178] linter --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 267f4d8e..729f03d7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -478,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict]) + clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict) timer.record("find config") From abb948dab09841571dd24c6be9ff9d6b212778ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:28:01 +0300 Subject: [PATCH 129/178] raise maximum Negative Guidance minimum sigma due to request in PR discussion --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 234ede0d..89b7132e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -439,7 +439,7 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { options_templates.update(options_section(('optimizations', "Optimizations"), { "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), - "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), + "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), From 714c920c20d07d70a0dd07c8c5cb54d9378e92c4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:47:44 +0300 Subject: [PATCH 130/178] do not run workflow items twice for PRs from this repo update names --- .github/workflows/on_pull_request.yaml | 6 +++++- .github/workflows/run_tests.yaml | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml index 8ebf5918..c56eea6b 100644 --- a/.github/workflows/on_pull_request.yaml +++ b/.github/workflows/on_pull_request.yaml @@ -1,4 +1,4 @@ -name: Run Linting/Formatting on Pull Requests +name: Linter on: - push @@ -6,7 +6,9 @@ on: jobs: lint-python: + name: Python linter runs-on: ubuntu-latest + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name steps: - name: Checkout Code uses: actions/checkout@v3 @@ -22,7 +24,9 @@ jobs: - name: Run Ruff run: ruff . lint-js: + name: Javascript linter runs-on: ubuntu-latest + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name steps: - name: Checkout Code uses: actions/checkout@v3 diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 178c026a..2af21448 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -1,4 +1,4 @@ -name: Run basic features tests on CPU with empty SD model +name: Tests on: - push @@ -6,7 +6,9 @@ on: jobs: test: + name: Tests on CPU with empty model runs-on: ubuntu-latest + if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name steps: - name: Checkout Code uses: actions/checkout@v3 From 9a3f35b028a8026291679c35e1df5b2aea327a1d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:56:01 +0300 Subject: [PATCH 131/178] repair medvram and lowvram --- modules/lowvram.py | 4 +++- modules/sd_hijack_open_clip.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/lowvram.py b/modules/lowvram.py index da4f33a8..6bbc11eb 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -100,7 +100,9 @@ def setup_for_low_vram(sd_model, use_medvram): sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) if sd_model.embedder: sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) - parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model + + if hasattr(sd_model, 'cond_stage_model'): + parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model if use_medvram: sd_model.model.register_forward_pre_hook(send_me_to_gpu) diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index fcf5ad07..bb0b96c7 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -32,7 +32,7 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit def encode_embedding_init_text(self, init_text, nvpt): ids = tokenizer.encode(init_text) ids = torch.asarray([ids], device=devices.device, dtype=torch.int) - embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0) return embedded @@ -66,6 +66,6 @@ class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWi def encode_embedding_init_text(self, init_text, nvpt): ids = tokenizer.encode(init_text) ids = torch.asarray([ids], device=devices.device, dtype=torch.int) - embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0) return embedded From 62e32634677f872d0325a8c9330bb7c12fe1f310 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 10:07:08 +0300 Subject: [PATCH 132/178] edit names more --- .github/workflows/on_pull_request.yaml | 4 ++-- .github/workflows/run_tests.yaml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml index c56eea6b..78e608ee 100644 --- a/.github/workflows/on_pull_request.yaml +++ b/.github/workflows/on_pull_request.yaml @@ -6,7 +6,7 @@ on: jobs: lint-python: - name: Python linter + name: ruff runs-on: ubuntu-latest if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name steps: @@ -24,7 +24,7 @@ jobs: - name: Run Ruff run: ruff . lint-js: - name: Javascript linter + name: eslint runs-on: ubuntu-latest if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name steps: diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 2af21448..e9370cc0 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -6,7 +6,7 @@ on: jobs: test: - name: Tests on CPU with empty model + name: tests on CPU with empty model runs-on: ubuntu-latest if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name steps: From 471a5a66b73921d569242daccc5275cb195e3f06 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 17:54:09 +0300 Subject: [PATCH 133/178] add more relevant fields to caching conds --- modules/processing.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f01a6907..f68e010d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -330,8 +330,21 @@ class StableDiffusionProcessing: caches is a list with items described above. """ + + cached_params = ( + required_prompts, + steps, + opts.CLIP_stop_at_last_layers, + shared.sd_model.sd_checkpoint_info, + extra_network_data, + opts.sdxl_crop_left, + opts.sdxl_crop_top, + self.width, + self.height, + ) + for cache in caches: - if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: + if cache[0] is not None and cached_params == cache[0]: return cache[1] cache = caches[0] @@ -339,7 +352,7 @@ class StableDiffusionProcessing: with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps) - cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) + cache[0] = cached_params return cache[1] def setup_conds(self): From ac2d47ff4c00b041cae3d882c2832662c2c64935 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 20:27:41 +0300 Subject: [PATCH 134/178] add cheap VAE approximation coeffs for SDXL --- modules/sd_vae_approx.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index b348f3ae..86bd658a 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -64,12 +64,22 @@ def model(): def cheap_approximation(sample): # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 - coefs = torch.tensor([ - [0.298, 0.207, 0.208], - [0.187, 0.286, 0.173], - [-0.158, 0.189, 0.264], - [-0.184, -0.271, -0.473], - ]).to(sample.device) + if shared.sd_model.is_sdxl: + coeffs = [ + [ 0.3448, 0.4168, 0.4395], + [-0.1953, -0.0290, 0.0250], + [ 0.1074, 0.0886, -0.0163], + [-0.3730, -0.2499, -0.2088], + ] + else: + coeffs = [ + [ 0.298, 0.207, 0.208], + [ 0.187, 0.286, 0.173], + [-0.158, 0.189, 0.264], + [-0.184, -0.271, -0.473], + ] + + coefs = torch.tensor(coeffs).to(sample.device) x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) From 5dee0fa1f812cf9f5fa6675c22c9a57afad39983 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 21:41:21 +0300 Subject: [PATCH 135/178] add a message about unsupported samplers --- modules/sd_samplers.py | 3 +++ modules/sd_samplers_compvis.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index f22aad8f..bea2684c 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -28,6 +28,9 @@ def create_sampler(name, model): assert config is not None, f'bad sampler name: {name}' + if model.is_sdxl and config.options.get("no_sdxl", False): + raise Exception(f"Sampler {config.name} is not supported for SDXL") + sampler = config.constructor(model) sampler.config = config diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index bdae8b40..4a8396f9 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -11,9 +11,9 @@ import modules.models.diffusion.uni_pc samplers_data_compvis = [ - sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}), - sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), - sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}), + sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}), + sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}), + sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}), ] From 95ee0cb18817df3c4fae2e7ba7063b79b0c60b9c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 22:51:58 +0300 Subject: [PATCH 136/178] restyle time taken/VRAM display --- javascript/hints.js | 2 -- modules/call_queue.py | 18 +++++++++++++----- style.css | 17 ++++++++++++++--- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/javascript/hints.js b/javascript/hints.js index dc75ce31..41201b2f 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -84,8 +84,6 @@ var titles = { "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.", "Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.", - "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).", - "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", diff --git a/modules/call_queue.py b/modules/call_queue.py index 3b94f8a4..61aa240f 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -85,9 +85,9 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): elapsed = time.perf_counter() - t elapsed_m = int(elapsed // 60) elapsed_s = elapsed % 60 - elapsed_text = f"{elapsed_s:.2f}s" + elapsed_text = f"{elapsed_s:.1f} sec." if elapsed_m > 0: - elapsed_text = f"{elapsed_m}m "+elapsed_text + elapsed_text = f"{elapsed_m} min. "+elapsed_text if run_memmon: mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} @@ -95,14 +95,22 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): reserved_peak = mem_stats['reserved_peak'] sys_peak = mem_stats['system_peak'] sys_total = mem_stats['total'] - sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) + sys_pct = sys_peak/max(sys_total, 1) * 100 - vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" + toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)" + toltip_r = "Reserved: total amout of video memory allocated by the Torch library " + toltip_sys = "System: peak amout of video memory allocated by all running programs, out of total capacity" + + text_a = f"A: {active_peak/1024:.2f} GB" + text_r = f"R: {reserved_peak/1024:.2f} GB" + text_sys = f"Sys: {sys_peak/1024:.1f}/{sys_total/1024:g} GB ({sys_pct:.1f}%)" + + vram_html = f"

{text_a}, {text_r}, {text_sys}

" else: vram_html = '' # last item is always HTML - res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" return tuple(res) diff --git a/style.css b/style.css index 5073f0f0..27ea6467 100644 --- a/style.css +++ b/style.css @@ -230,17 +230,28 @@ button.custom-button{ .performance { font-size: 0.85em; color: #444; + display: flex; } .performance p{ display: inline-block; } -.performance .time { - margin-right: 0; +.performance p.time, .performance p.vram, .performance p.time abbr, .performance p.vram abbr { + margin-bottom: 0; + color: var(--block-title-text-color); } -.performance .vram { +.performance p.time { +} + +.performance p.vram { + margin-left: auto; +} + +.performance .measurement{ + color: var(--body-text-color); + font-weight: bold; } #txt2img_generate, #img2img_generate { From 5d94088eac401545286ef8b16455cc88e9797300 Mon Sep 17 00:00:00 2001 From: Marcus Adams Date: Fri, 14 Jul 2023 21:52:00 -0400 Subject: [PATCH 137/178] Added [none] filename token. --- modules/images.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/images.py b/modules/images.py index 4bdedb7f..fb5d2e75 100644 --- a/modules/images.py +++ b/modules/images.py @@ -380,6 +380,7 @@ class FilenameGenerator: 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT, 'user': lambda self: self.p.user, 'vae_filename': lambda self: self.get_vae_filename(), + 'none': lambda self: '', # Overrides the default so you can get just the sequence number } default_time_format = '%Y%m%d%H%M%S' @@ -601,13 +602,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else: file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]" + file_decoration = namegen.apply(file_decoration) + suffix + add_number = opts.save_images_add_number or file_decoration == '' if file_decoration != "" and add_number: file_decoration = f"-{file_decoration}" - file_decoration = namegen.apply(file_decoration) + suffix - if add_number: basecount = get_next_sequence_number(path, basename) fullfn = None From 14cf434bc36d0ef31f31d4c6cd2bd15d7857d5c8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 07:33:16 +0300 Subject: [PATCH 138/178] fix an issue in live previews that happens when you use SDXL with fp16 VAE --- modules/processing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f68e010d..eb4a60eb 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -539,8 +539,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x): - with devices.autocast(disable=x.dtype == devices.dtype_vae): - x = model.decode_first_stage(x) + x = model.decode_first_stage(x.to(devices.dtype_vae)) return x From b8bd8ce4cf687e9e02000387adf4d751b22a4a36 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 07:44:37 +0300 Subject: [PATCH 139/178] disable rich exception output in console for API by default, use WEBUI_RICH_EXCEPTIONS env var to enable --- modules/api/api.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 11045292..2a4cd8a2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,5 +1,6 @@ import base64 import io +import os import time import datetime import uvicorn @@ -98,14 +99,16 @@ def encode_pil_to_base64(image): def api_middleware(app: FastAPI): - rich_available = True + rich_available = False try: - import anyio # importing just so it can be placed on silent list - import starlette # importing just so it can be placed on silent list - from rich.console import Console - console = Console() + if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None: + import anyio # importing just so it can be placed on silent list + import starlette # importing just so it can be placed on silent list + from rich.console import Console + console = Console() + rich_available = True except Exception: - rich_available = False + pass @app.middleware("http") async def log_and_time(req: Request, call_next): @@ -116,14 +119,14 @@ def api_middleware(app: FastAPI): endpoint = req.scope.get('path', 'err') if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'): print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( - t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), - code = res.status_code, - ver = req.scope.get('http_version', '0.0'), - cli = req.scope.get('client', ('0:0.0.0', 0))[0], - prot = req.scope.get('scheme', 'err'), - method = req.scope.get('method', 'err'), - endpoint = endpoint, - duration = duration, + t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + code=res.status_code, + ver=req.scope.get('http_version', '0.0'), + cli=req.scope.get('client', ('0:0.0.0', 0))[0], + prot=req.scope.get('scheme', 'err'), + method=req.scope.get('method', 'err'), + endpoint=endpoint, + duration=duration, )) return res @@ -134,7 +137,7 @@ def api_middleware(app: FastAPI): "body": vars(e).get('body', ''), "errors": str(e), } - if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions + if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions message = f"API error: {request.method}: {request.url} {err}" if rich_available: print(message) From 127635409a7959f6c057a68ccb8e70734cbaf9f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 08:07:25 +0300 Subject: [PATCH 140/178] add padding and identification to generation log section (Failed to find Loras, Used embeddings, etc...) --- modules/img2img.py | 2 +- modules/txt2img.py | 2 +- modules/ui.py | 3 +-- modules/ui_common.py | 9 +++++---- style.css | 16 ++++++++++------ 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 664e2688..a811e7a4 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -240,4 +240,4 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if opts.do_not_show_images: processed.images = [] - return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) + return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments") diff --git a/modules/txt2img.py b/modules/txt2img.py index d0be2e73..29d94e8c 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -70,4 +70,4 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step if opts.do_not_show_images: processed.images = [] - return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) + return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments") diff --git a/modules/ui.py b/modules/ui.py index 39d226ad..07ecee7b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -83,8 +83,7 @@ detect_image_size_symbol = '\U0001F4D0' # 📐 up_down_symbol = '\u2195\ufe0f' # ↕️ -def plaintext_to_html(text): - return ui_common.plaintext_to_html(text) +plaintext_to_html = ui_common.plaintext_to_html def send_gradio_gallery_to_image(x): diff --git a/modules/ui_common.py b/modules/ui_common.py index 57c2d0ad..11eb2a4b 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -29,9 +29,10 @@ def update_generation_info(generation_info, html_info, img_index): return html_info, gr.update() -def plaintext_to_html(text): - text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" - return text +def plaintext_to_html(text, classname=None): + content = "
\n".join(html.escape(x) for x in text.split('\n')) + + return f"

{content}

" if classname else f"

{content}

" def save_files(js_data, images, do_make_zip, index): @@ -157,7 +158,7 @@ Requested path was: {f} with gr.Group(): html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") - html_log = gr.HTML(elem_id=f'html_log_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': diff --git a/style.css b/style.css index 27ea6467..a424067f 100644 --- a/style.css +++ b/style.css @@ -227,29 +227,33 @@ button.custom-button{ align-self: end; } -.performance { +.html-log .comments{ + padding-top: 0.5em; +} + +.html-log .performance { font-size: 0.85em; color: #444; display: flex; } -.performance p{ +.html-log .performance p{ display: inline-block; } -.performance p.time, .performance p.vram, .performance p.time abbr, .performance p.vram abbr { +.html-log .performance p.time, .performance p.vram, .performance p.time abbr, .performance p.vram abbr { margin-bottom: 0; color: var(--block-title-text-color); } -.performance p.time { +.html-log .performance p.time { } -.performance p.vram { +.html-log .performance p.vram { margin-left: auto; } -.performance .measurement{ +.html-log .performance .measurement{ color: var(--body-text-color); font-weight: bold; } From 2b1bae0d755c2d5201f6a6aadeadb5588208d43f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 08:41:22 +0300 Subject: [PATCH 141/178] add textual inversion hashes to infotext --- modules/processing.py | 7 ++++--- modules/sd_hijack.py | 5 ++++- modules/sd_hijack_clip.py | 15 ++++++++++++--- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 9 ++++++++- style.css | 4 ++++ 6 files changed, 33 insertions(+), 8 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index cd568a20..49441e77 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -732,9 +732,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.setup_conds() - if len(model_hijack.comments) > 0: - for comment in model_hijack.comments: - comments[comment] = 1 + for comment in model_hijack.comments: + comments[comment] = 1 + + p.extra_generation_params.update(model_hijack.extra_generation_params) if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3b6f95ce..6b5aae4b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -147,7 +147,6 @@ def undo_weighted_forward(sd_model): class StableDiffusionModelHijack: fixes = None - comments = [] layers = None circular_enabled = False clip = None @@ -156,6 +155,9 @@ class StableDiffusionModelHijack: embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() def __init__(self): + self.extra_generation_params = {} + self.comments = [] + self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) def apply_optimizations(self, option=None): @@ -236,6 +238,7 @@ class StableDiffusionModelHijack: def clear_comments(self): self.comments = [] + self.extra_generation_params = {} def get_prompt_lengths(self, text): if self.clip is None: diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 3b5a7666..c1d780a3 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -229,9 +229,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): z = self.process_tokens(tokens, multipliers) zs.append(z) - if len(used_embeddings) > 0: - embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) - self.hijack.comments.append(f"Used embeddings: {embeddings_list}") + if opts.textual_inversion_add_hashes_to_infotext and used_embeddings: + hashes = [] + for name, embedding in used_embeddings.items(): + shorthash = embedding.shorthash + if not shorthash: + continue + + name = name.replace(":", "").replace(",", "") + hashes.append(f"{name}: {shorthash}") + + if hashes: + self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) return torch.hstack(zs) diff --git a/modules/shared.py b/modules/shared.py index 48478a68..a32fd4ed 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -472,6 +472,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(), + "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks), })) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cbe975b7..38e072a8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -13,7 +13,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -49,6 +49,8 @@ class Embedding: self.sd_checkpoint_name = None self.optimizer_state_dict = None self.filename = None + self.hash = None + self.shorthash = None def save(self, filename): embedding_data = { @@ -82,6 +84,10 @@ class Embedding: self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' return self.cached_checksum + def set_hash(self, v): + self.hash = v + self.shorthash = self.hash[0:12] + class DirWithTextualInversionEmbeddings: def __init__(self, path): @@ -199,6 +205,7 @@ class EmbeddingDatabase: embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] embedding.filename = path + embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) diff --git a/style.css b/style.css index a424067f..9e13d7fd 100644 --- a/style.css +++ b/style.css @@ -231,6 +231,10 @@ button.custom-button{ padding-top: 0.5em; } +.html-log .comments:empty{ + padding-top: 0; +} + .html-log .performance { font-size: 0.85em; color: #444; From 510e5fc8c60dd6278d0bc52effc23257c717dc1b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 09:20:43 +0300 Subject: [PATCH 142/178] cache git extension repo information --- modules/cache.py | 96 ++++++++++++++++++++++++++++++++++++++++ modules/extensions.py | 26 ++++++++--- modules/hashes.py | 38 ++-------------- modules/ui_extensions.py | 6 --- 4 files changed, 119 insertions(+), 47 deletions(-) create mode 100644 modules/cache.py diff --git a/modules/cache.py b/modules/cache.py new file mode 100644 index 00000000..4c2db604 --- /dev/null +++ b/modules/cache.py @@ -0,0 +1,96 @@ +import json +import os.path + +import filelock + +from modules.paths import data_path, script_path + +cache_filename = os.path.join(data_path, "cache.json") +cache_data = None + + +def dump_cache(): + """ + Saves all cache data to a file. + """ + + with filelock.FileLock(f"{cache_filename}.lock"): + with open(cache_filename, "w", encoding="utf8") as file: + json.dump(cache_data, file, indent=4) + + +def cache(subsection): + """ + Retrieves or initializes a cache for a specific subsection. + + Parameters: + subsection (str): The subsection identifier for the cache. + + Returns: + dict: The cache data for the specified subsection. + """ + + global cache_data + + if cache_data is None: + with filelock.FileLock(f"{cache_filename}.lock"): + if not os.path.isfile(cache_filename): + cache_data = {} + else: + try: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') + cache_data = {} + + s = cache_data.get(subsection, {}) + cache_data[subsection] = s + + return s + + +def cached_data_for_file(subsection, title, filename, func): + """ + Retrieves or generates data for a specific file, using a caching mechanism. + + Parameters: + subsection (str): The subsection of the cache to use. + title (str): The title of the data entry in the subsection of the cache. + filename (str): The path to the file to be checked for modifications. + func (callable): A function that generates the data if it is not available in the cache. + + Returns: + dict or None: The cached or generated data, or None if data generation fails. + + The `cached_data_for_file` function implements a caching mechanism for data stored in files. + It checks if the data associated with the given `title` is present in the cache and compares the + modification time of the file with the cached modification time. If the file has been modified, + the cache is considered invalid and the data is regenerated using the provided `func`. + Otherwise, the cached data is returned. + + If the data generation fails, None is returned to indicate the failure. Otherwise, the generated + or cached data is returned as a dictionary. + """ + + existing_cache = cache(subsection) + ondisk_mtime = os.path.getmtime(filename) + + entry = existing_cache.get(title) + if entry: + cached_mtime = existing_cache[title].get("mtime", 0) + if ondisk_mtime > cached_mtime: + entry = None + + if not entry: + entry = func() + if entry is None: + return None + + entry['mtime'] = ondisk_mtime + existing_cache[title] = entry + + dump_cache() + + return entry diff --git a/modules/extensions.py b/modules/extensions.py index abc6e2b1..c561159a 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,7 +1,7 @@ import os import threading -from modules import shared, errors +from modules import shared, errors, cache from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 @@ -21,6 +21,7 @@ def active(): class Extension: lock = threading.Lock() + cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] def __init__(self, name, path, enabled=True, is_builtin=False): self.name = name @@ -36,15 +37,29 @@ class Extension: self.remote = None self.have_info_from_repo = False + def to_dict(self): + return {x: getattr(self, x) for x in self.cached_fields} + + def from_dict(self, d): + for field in self.cached_fields: + setattr(self, field, d[field]) + def read_info_from_repo(self): if self.is_builtin or self.have_info_from_repo: return - with self.lock: - if self.have_info_from_repo: - return + def read_from_repo(): + with self.lock: + if self.have_info_from_repo: + return - self.do_read_info_from_repo() + self.do_read_info_from_repo() + + return self.to_dict() + + d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo) + self.from_dict(d) + self.status = 'unknown' def do_read_info_from_repo(self): repo = None @@ -58,7 +73,6 @@ class Extension: self.remote = None else: try: - self.status = 'unknown' self.remote = next(repo.remote().urls, None) commit = repo.head.commit self.commit_date = commit.committed_date diff --git a/modules/hashes.py b/modules/hashes.py index ec1187fe..b7a33b42 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -1,43 +1,11 @@ import hashlib -import json import os.path -import filelock - from modules import shared -from modules.paths import data_path, script_path +import modules.cache - -cache_filename = os.path.join(data_path, "cache.json") -cache_data = None - - -def dump_cache(): - with filelock.FileLock(f"{cache_filename}.lock"): - with open(cache_filename, "w", encoding="utf8") as file: - json.dump(cache_data, file, indent=4) - - -def cache(subsection): - global cache_data - - if cache_data is None: - with filelock.FileLock(f"{cache_filename}.lock"): - if not os.path.isfile(cache_filename): - cache_data = {} - else: - try: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) - except Exception: - os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) - print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') - cache_data = {} - - s = cache_data.get(subsection, {}) - cache_data[subsection] = s - - return s +dump_cache = modules.cache.dump_cache +cache = modules.cache.cache def calculate_sha256(filename): diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index dff522ef..3fa3dea2 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -513,14 +513,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" def preload_extensions_git_metadata(): - t0 = time.time() for extension in extensions.extensions: extension.read_info_from_repo() - print( - f"preload_extensions_git_metadata for " - f"{len(extensions.extensions)} extensions took " - f"{time.time() - t0:.2f}s" - ) def create_ui(): From 0aa8d538e147ba87df36d8196845807c8fa3f4e1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 09:24:22 +0300 Subject: [PATCH 143/178] suppress printing TI embedding into console by default --- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index a32fd4ed..427dcc50 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -472,6 +472,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(), + "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks), })) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 38e072a8..6166c76f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -256,7 +256,7 @@ class EmbeddingDatabase: self.word_embeddings.update(sorted_word_embeddings) displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) - if self.previously_displayed_embeddings != displayed_embeddings: + if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings: self.previously_displayed_embeddings = displayed_embeddings print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if self.skipped_embeddings: From c58cf73c806f08eb8b96bccc2af64403d903695f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 09:33:21 +0300 Subject: [PATCH 144/178] remove "## " from changelog.md version --- modules/launch_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 0e0dbca4..ff77cbfd 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -69,10 +69,12 @@ def git_tag(): return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip() except Exception: try: - from pathlib import Path - changelog_md = Path(__file__).parent.parent / "CHANGELOG.md" - with changelog_md.open(encoding="utf-8") as file: - return next((line.strip() for line in file if line.strip()), "") + + changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md") + with open(changelog_md, "r", encoding="utf-8") as file: + line = next((line.strip() for line in file if line.strip()), "") + line = line.replace("## ", "") + return line except Exception: return "" From 2d9d53be21d339a0723276517fff067db0181af5 Mon Sep 17 00:00:00 2001 From: Jabasukuriputo Wang Date: Sat, 15 Jul 2023 17:09:51 +0800 Subject: [PATCH 145/178] allow replacing extensions index with environment variable --- modules/ui_extensions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 3fa3dea2..f3e4fba7 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,5 +1,5 @@ import json -import os.path +import os import threading import time from datetime import datetime @@ -564,7 +564,8 @@ def create_ui(): with gr.TabItem("Available", id="available"): with gr.Row(): refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") - available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False) + extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json") + available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False) extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) From 2970d712ee52eaffee36d0e86cc7def71393a9b5 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 16 Jul 2023 00:59:31 +0900 Subject: [PATCH 146/178] Warns merge into master --- .github/workflows/warns_merge_master.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/warns_merge_master.yml diff --git a/.github/workflows/warns_merge_master.yml b/.github/workflows/warns_merge_master.yml new file mode 100644 index 00000000..ae2aab6b --- /dev/null +++ b/.github/workflows/warns_merge_master.yml @@ -0,0 +1,19 @@ +name: Pull requests can't target master branch + +"on": + pull_request: + types: + - opened + - synchronize + - reopened + branches: + - master + +jobs: + check: + runs-on: ubuntu-latest + steps: + - name: Warning marge into master + run: | + echo -e "::warning::This pull request directly merge into \"master\" branch, normally development happens on \"dev\" branch." + exit 1 From e5d3ae2bf4e9d39c35e6edc96d6449fd42528e55 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 20:39:04 +0300 Subject: [PATCH 147/178] user metadata system for custom networks --- .../Lora/ui_extra_networks_lora.py | 2 +- html/extra-networks-card.html | 8 +- javascript/extraNetworks.js | 37 +++- modules/ui_extra_networks.py | 54 +++++- modules/ui_extra_networks_checkpoints.py | 2 +- modules/ui_extra_networks_hypernets.py | 6 +- modules/ui_extra_networks_user_metadata.py | 169 ++++++++++++++++++ style.css | 56 ++++-- 8 files changed, 300 insertions(+), 34 deletions(-) create mode 100644 modules/ui_extra_networks_user_metadata.py diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index da49790b..29b16c1c 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -20,7 +20,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): yield { "name": name, - "filename": path, + "filename": lora_on_disk.filename, "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 68a84c3a..fb787ffe 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,11 +1,11 @@
{background_image} - {metadata_button} +
+ {edit_button} + {metadata_button} +
-
{name} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index b87bca3e..68f342de 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -182,19 +182,20 @@ function extraNetworksSearchButton(tabs_id, event) { var globalPopup = null; var globalPopupInner = null; +function closePopup(){ + if (!globalPopup) return; + + globalPopup.style.display = "none"; +} function popup(contents) { if (!globalPopup) { globalPopup = document.createElement('div'); - globalPopup.onclick = function() { - globalPopup.style.display = "none"; - }; + globalPopup.onclick = closePopup; globalPopup.classList.add('global-popup'); var close = document.createElement('div'); close.classList.add('global-popup-close'); - close.onclick = function() { - globalPopup.style.display = "none"; - }; + close.onclick = closePopup; close.title = "Close"; globalPopup.appendChild(close); @@ -263,3 +264,27 @@ function extraNetworksRequestMetadata(event, extraPage, cardName) { event.stopPropagation(); } + +extraPageUserMetadataEditors = {} + +function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { + var id = tabname + '_' + extraPage + '_edit_user_metadata'; + + editor = extraPageUserMetadataEditors[id] + if(! editor){ + editor = {}; + editor.page = gradioApp().getElementById(id); + editor.nameTextarea = gradioApp().querySelector("#" + id + "_name" + ' textarea'); + editor.button = gradioApp().querySelector("#" + id + "_button"); + extraPageUserMetadataEditors[id] = editor; + } + + editor.nameTextarea.value = cardName; + updateInput(editor.nameTextarea); + + editor.button.click(); + + popup(editor.page); + + event.stopPropagation(); +} diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 693cafb6..eaae6217 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,7 +2,7 @@ import os.path import urllib.parse from pathlib import Path -from modules import shared +from modules import shared, ui_extra_networks_user_metadata, errors from modules.images import read_info_from_image, save_image_with_geninfo from modules.ui import up_down_symbol import gradio as gr @@ -60,13 +60,34 @@ class ExtraNetworksPage: def __init__(self, title): self.title = title self.name = title.lower() + self.id_page = self.name.replace(" ", "_") self.card_page = shared.html("extra-networks-card.html") self.allow_negative_prompt = False self.metadata = {} + self.items = {} def refresh(self): pass + def read_user_metadata(self, item): + filename = item.get("filename", None) + basename, ext = os.path.splitext(filename) + metadata_filename = basename + '.json' + + metadata = {} + try: + if os.path.isfile(metadata_filename): + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except Exception as e: + errors.display(e, f"reading extra network user metadata from {metadata_filename}") + + desc = metadata.get("description", None) + if desc is not None: + item["description"] = desc + + item["user_metadata"] = metadata + def link_preview(self, filename): quoted_filename = urllib.parse.quote(filename.replace('\\', '/')) mtime = os.path.getmtime(filename) @@ -119,11 +140,15 @@ class ExtraNetworksPage: """ for subdir in subdirs]) - for item in self.list_items(): + self.items = {x["name"]: x for x in self.list_items()} + for item in self.items.values(): metadata = item.get("metadata") if metadata: self.metadata[item["name"]] = metadata + if "user_metadata" not in item: + self.read_user_metadata(item) + items_html += self.create_html_for_item(item, tabname) if items_html == '': @@ -166,7 +191,9 @@ class ExtraNetworksPage: metadata_button = "" metadata = item.get("metadata") if metadata: - metadata_button = f"" + metadata_button = f"" + + edit_button = f"
" local_path = "" filename = item.get("filename", "") @@ -200,6 +227,7 @@ class ExtraNetworksPage: "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), "metadata_button": metadata_button, + "edit_button": edit_button, "search_only": " search_only" if search_only else "", "sort_keys": sort_keys, } @@ -247,6 +275,9 @@ class ExtraNetworksPage: pass return None + def create_user_metadata_editor(self, ui, tabname): + return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self) + def initialize(): extra_pages.clear() @@ -297,20 +328,23 @@ def create_ui(container, button, tabname): ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] + ui.user_metadata_editors = [] ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.tabname = tabname with gr.Tabs(elem_id=tabname+"_extra_tabs"): for page in ui.stored_extra_pages: - page_id = page.title.lower().replace(" ", "_") - - with gr.Tab(page.title, id=page_id): - elem_id = f"{tabname}_{page_id}_cards_html" + with gr.Tab(page.title, id=page.id_page): + elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[]) + editor = page.create_user_metadata_editor(ui, tabname) + editor.create_ui() + ui.user_metadata_editors.append(editor) + gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True) gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder") @@ -363,6 +397,8 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): def save_preview(index, images, filename): + # this function is here for backwards compatibility and likely will be removed soon + if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] @@ -394,3 +430,7 @@ def setup_ui(ui, gallery): outputs=[*ui.pages] ) + for editor in ui.user_metadata_editors: + editor.setup_ui(gallery) + + diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 8b9ab71b..bb5071e6 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -18,7 +18,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): path, ext = os.path.splitext(checkpoint.filename) yield { "name": checkpoint.name_for_extra, - "filename": path, + "filename": checkpoint.filename, "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 7c19b532..ea0b7a44 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -12,12 +12,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): shared.reload_hypernetworks() def list_items(self): - for index, (name, path) in enumerate(shared.hypernetworks.items()): - path, ext = os.path.splitext(path) + for index, (name, full_path) in enumerate(shared.hypernetworks.items()): + path, ext = os.path.splitext(full_path) yield { "name": name, - "filename": path, + "filename": full_path, "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(path), diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py new file mode 100644 index 00000000..8d20d026 --- /dev/null +++ b/modules/ui_extra_networks_user_metadata.py @@ -0,0 +1,169 @@ +import datetime +import html +import json +import os.path + +import gradio as gr + +from modules import generation_parameters_copypaste, images, sysinfo, errors + + +class UserMetadataEditor: + + def __init__(self, ui, tabname, page): + self.ui = ui + self.tabname = tabname + self.page = page + self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata" + + self.box = None + + self.edit_name_input = None + self.button_edit = None + + self.edit_name = None + self.edit_description = None + self.html_filedata = None + self.html_preview = None + + self.button_cancel = None + self.button_replace_preview = None + self.button_save = None + + def get_user_metadata(self, name): + item = self.page.items.get(name, {}) + + user_metadata = item.get('user_metadata', None) + if user_metadata is None: + user_metadata = {} + item['user_metadata'] = user_metadata + + return user_metadata + + def create_default_editor_elems(self): + with gr.Row(): + with gr.Column(scale=2): + self.edit_name = gr.HTML(elem_classes="extra-network-name") + self.edit_description = gr.Textbox(label="Description", lines=4) + self.html_filedata = gr.HTML() + + with gr.Column(scale=1, min_width=0): + self.html_preview = gr.HTML() + + def create_default_buttons(self): + + with gr.Row(): + self.button_cancel = gr.Button('Cancel') + self.button_replace_preview = gr.Button('Replace preview', variant='primary') + self.button_save = gr.Button('Save', variant='primary') + + self.button_cancel.click(fn=None, _js="closePopup") + + def get_card_html(self, name): + item = self.page.items.get(name, {}) + + preview_url = item.get("preview", None) + + if not preview_url: + filename, _ = os.path.splitext(item["filename"]) + preview_url = self.page.find_preview(filename) + item["preview"] = preview_url + + if preview_url: + preview = f''' +
+ +
+ ''' + else: + preview = "
" + + return preview + + def get_metadata_table(self, name): + item = self.page.items.get(name, {}) + try: + filename = item["filename"] + + stats = os.stat(filename) + params = [ + ('File size: ', sysinfo.pretty_bytes(stats.st_size)), + ('Created: ', datetime.datetime.fromtimestamp(stats.st_ctime).strftime('%Y-%m-%d %H:%M')), + ('Last modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')), + ] + + return params + except Exception as e: + errors.display(e, f"reading info for {name}") + return [] + + def put_values_into_components(self, name): + user_metadata = self.get_user_metadata(name) + + params = self.get_metadata_table(name) + table = '
Extension + + Extension + URL Branch Version
{html.escape(ext.name)}{html.escape(ext.name)} {remote} {ext.branch} {version_link}
{html.escape(name)}
{tags_text}
{html.escape(description)}

Added: {html.escape(added)}

{html.escape(description)}

Added: {html.escape(added)}stars: {stars:,}

{install_code}
{html.escape(name)}
{tags_text}
{html.escape(description)}

Added: {html.escape(added)}stars: {stars:,}

{html.escape(description)}

+ Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}stars: {stars}

{install_code}
' + "".join(f"" for name, value in params) + '' + + return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name) + + def write_user_metadata(self, name, metadata): + item = self.page.items.get(name, {}) + filename = item.get("filename", None) + basename, ext = os.path.splitext(filename) + + with open(basename + '.json', "w", encoding="utf8") as file: + json.dump(metadata, file) + + def save_user_metadata(self, name, desc): + user_metadata = self.get_user_metadata(name) + user_metadata["description"] = desc + + self.write_user_metadata(name, user_metadata) + + def create_editor(self): + self.create_default_editor_elems() + + self.create_default_buttons() + + self.button_edit\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview])\ + .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) + + self.button_save.click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[]).then(fn=None, _js="closePopup") + + def create_ui(self): + with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box: + self.box = box + + self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name") + self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button") + + self.create_editor() + + def save_preview(self, index, gallery, name): + if len(gallery) == 0: + print("There is no image in gallery to save as a preview.") + return [self.get_card_html(name)] + [page.create_html(self.ui.tabname) for page in self.ui.stored_extra_pages] + + item = self.page.items.get(name, {}) + + index = int(index) + index = 0 if index < 0 else index + index = len(gallery) - 1 if index >= len(gallery) else index + + img_info = gallery[index if index >= 0 else 0] + image = generation_parameters_copypaste.image_from_url_text(img_info) + geninfo, items = images.read_info_from_image(image) + + images.save_image_with_geninfo(image, geninfo, item["local_preview"]) + + return [self.get_card_html(name)] + [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] + + def setup_ui(self, gallery): + self.button_replace_preview.click( + fn=self.save_preview, + _js="function(x, y, z){return [selected_gallery_index(), y, z]}", + inputs=[self.edit_name_input, gallery, self.edit_name_input], + outputs=[self.html_preview, *self.ui.pages] + ) + + diff --git a/style.css b/style.css index 9e13d7fd..4431c1aa 100644 --- a/style.css +++ b/style.css @@ -550,6 +550,9 @@ table.popup-table .link{ background-color: rgba(20, 20, 20, 0.95); } +.global-popup *{ + box-sizing: border-box; +} .global-popup-close:before { content: "×"; @@ -815,32 +818,42 @@ footer { } -.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{ - content: "🛈"; -} -.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{ +.extra-network-cards .card .button-row, .extra-network-thumbs .card .button-row{ display: none; position: absolute; color: white; right: 0; } -.extra-network-cards .card .metadata-button { +.extra-network-cards .card:hover .button-row, .extra-network-thumbs .card:hover .button-row{ + display: flex; +} + +.extra-network-cards .card .card-button, .extra-network-thumbs .card .card-button{ + color: white; +} + +.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{ + content: "🛈"; +} + +.extra-network-cards .card .edit-button:before, .extra-network-thumbs .card .edit-button:before{ + content: "🛠"; +} + +.extra-network-cards .card .card-button { text-shadow: 2px 2px 3px black; padding: 0.25em; font-size: 22pt; width: 1.5em; } -.extra-network-thumbs .card .metadata-button { +.extra-network-thumbs .card .card-button { text-shadow: 1px 1px 2px black; padding: 0; font-size: 16pt; width: 1em; top: -0.25em; } -.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{ - display: inline-block; -} -.extra-network-cards .card .metadata-button:hover, .extra-network-thumbs .card .metadata-button:hover{ +.extra-network-cards .card .card-button:hover, .extra-network-thumbs .card .card-button:hover{ color: red; } @@ -861,7 +874,7 @@ footer { position: relative; } -.extra-network-thumbs .card .preview{ +.extra-network-thumbs .card .preview, .standalone-card-preview.card .preview{ position: absolute; object-fit: cover; width: 100%; @@ -905,7 +918,7 @@ footer { word-break: break-all; } -.extra-network-cards .card{ +.extra-network-cards .card, .standalone-card-preview.card{ display: inline-block; margin: 0.5em; width: 16em; @@ -989,3 +1002,22 @@ footer { width: 100%; height:100%; } + +div.block.gradio-box.edit-user-metadata { + min-width: 56em; + background: var(--body-background-fill); + padding: 2em !important; +} + +.edit-user-metadata .extra-network-name{ + font-size: 18pt; + color: var(--body-text-color); +} + +.edit-user-metadata .file-metadata th{ + text-align: left; +} + +.edit-user-metadata .wrap.translucent{ + background: var(--body-background-fill); +} From 5decbf184b185026d5da9e2c7be02d06fd640f12 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 21:05:33 +0300 Subject: [PATCH 148/178] eslint --- javascript/extraNetworks.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 68f342de..7007b353 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -182,7 +182,7 @@ function extraNetworksSearchButton(tabs_id, event) { var globalPopup = null; var globalPopupInner = null; -function closePopup(){ +function closePopup() { if (!globalPopup) return; globalPopup.style.display = "none"; @@ -265,13 +265,13 @@ function extraNetworksRequestMetadata(event, extraPage, cardName) { event.stopPropagation(); } -extraPageUserMetadataEditors = {} +var extraPageUserMetadataEditors = {}; function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { var id = tabname + '_' + extraPage + '_edit_user_metadata'; - editor = extraPageUserMetadataEditors[id] - if(! editor){ + var editor = extraPageUserMetadataEditors[id]; + if (!editor) { editor = {}; editor.page = gradioApp().getElementById(id); editor.nameTextarea = gradioApp().querySelector("#" + id + "_name" + ' textarea'); From 11f339733de860b0b51adebe15dc945df7189edf Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 00:56:53 +0300 Subject: [PATCH 149/178] add lora user metadata editor dialog inspired by MrKuenning's mockup from #7458 --- .../Lora/ui_edit_user_metadata.py | 187 ++++++++++++++++++ .../Lora/ui_extra_networks_lora.py | 17 +- javascript/extraNetworks.js | 18 +- modules/ui_extra_networks_user_metadata.py | 23 ++- style.css | 9 +- 5 files changed, 241 insertions(+), 13 deletions(-) create mode 100644 extensions-builtin/Lora/ui_edit_user_metadata.py diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py new file mode 100644 index 00000000..c7dbd1c1 --- /dev/null +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -0,0 +1,187 @@ +import html +import json +import random + +import gradio as gr +import re + +from modules import ui_extra_networks_user_metadata + + +def is_non_comma_tagset(tags): + average_tag_length = sum(len(x) for x in tags.keys()) / len(tags) + + return average_tag_length >= 16 + + +re_word = re.compile(r"[-_\w']+") +re_comma = re.compile(r" *, *") + + +def build_tags(metadata): + tags = {} + + for _, tags_dict in metadata.get("ss_tag_frequency", {}).items(): + for tag, tag_count in tags_dict.items(): + tag = tag.strip() + tags[tag] = tags.get(tag, 0) + int(tag_count) + + if tags and is_non_comma_tagset(tags): + new_tags = {} + + for text, text_count in tags.items(): + for word in re.findall(re_word, text): + if len(word) < 3: + continue + + new_tags[word] = new_tags.get(word, 0) + text_count + + tags = new_tags + + ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True) + + return [(tag, tags[tag]) for tag in ordered_tags] + + +class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): + def __init__(self, ui, tabname, page): + super().__init__(ui, tabname, page) + + self.taginfo = None + self.edit_activation_text = None + self.slider_preferred_weight = None + self.edit_notes = None + + def save_lora_user_metadata(self, name, desc, activation_text, preferred_weight, notes): + user_metadata = self.get_user_metadata(name) + user_metadata["description"] = desc + user_metadata["activation text"] = activation_text + user_metadata["preferred weight"] = preferred_weight + user_metadata["notes"] = notes + + self.write_user_metadata(name, user_metadata) + + def get_metadata_table(self, name): + table = super().get_metadata_table(name) + item = self.page.items.get(name, {}) + metadata = json.loads(item.get("metadata") or '{}') + + keys = [ + ('ss_sd_model_name', "Model:"), + ('ss_resolution', "Resolution:"), + ('ss_clip_skip', "Clip skip:"), + ] + + for key, label in keys: + value = metadata.get(key, None) + if value is not None and str(value) != "None": + table.append((label, html.escape(value))) + + image_count = 0 + for _, params in metadata.get("ss_dataset_dirs", {}).items(): + image_count += int(params.get("img_count", 0)) + + if image_count: + table.append(("Dataset size:", image_count)) + + return table + + def put_values_into_components(self, name): + user_metadata = self.get_user_metadata(name) + values = super().put_values_into_components(name) + + item = self.page.items.get(name, {}) + metadata = json.loads(item.get("metadata") or '{}') + + tags = build_tags(metadata) + gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]] + + return [ + *values[0:4], + gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), + user_metadata.get('activation text', ''), + float(user_metadata.get('preferred weight', 0.0)), + user_metadata.get('notes', ''), + gr.update(visible=True if tags else False), + gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), + ] + + def generate_random_prompt(self, name): + item = self.page.items.get(name, {}) + metadata = json.loads(item.get("metadata") or '{}') + tags = build_tags(metadata) + + return self.generate_random_prompt_from_tags(tags) + + def generate_random_prompt_from_tags(self, tags): + max_count = None + res = [] + for tag, count in tags: + if not max_count: + max_count = count + + v = random.random() * max_count + if count > v: + res.append(tag) + + return ", ".join(sorted(res)) + + def create_editor(self): + self.create_default_editor_elems() + + self.taginfo = gr.HighlightedText(label="Tags") + self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") + self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) + + with gr.Row() as row_random_prompt: + with gr.Column(scale=8): + random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) + + with gr.Column(scale=1, min_width=120): + generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg") + + self.edit_notes = gr.TextArea(label='Notes', lines=4) + + generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt]) + + def select_tag(activation_text, evt: gr.SelectData): + tag = evt.value[0] + + words = re.split(re_comma, activation_text) + if tag in words: + words = [x for x in words if x != tag and x.strip()] + return ", ".join(words) + + return activation_text + ", " + tag if activation_text else tag + + self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False) + + self.create_default_buttons() + + viewed_components = [ + self.edit_name, + self.edit_description, + self.html_filedata, + self.html_preview, + self.taginfo, + self.edit_activation_text, + self.slider_preferred_weight, + self.edit_notes, + row_random_prompt, + random_prompt, + ] + + self.button_edit\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\ + .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) + + edited_components = [ + self.edit_description, + self.edit_activation_text, + self.slider_preferred_weight, + self.edit_notes, + ] + + self.button_save\ + .click(fn=self.save_lora_user_metadata, inputs=[self.edit_name_input, *edited_components], outputs=[]) \ + .then(fn=None, _js="extraNetworksReloadAll") diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 29b16c1c..95296275 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -3,6 +3,7 @@ import os import lora from modules import shared, ui_extra_networks +from ui_edit_user_metadata import LoraUserMetadataEditor class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): @@ -18,19 +19,29 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): alias = lora_on_disk.get_alias() - yield { + item = { "name": name, "filename": lora_on_disk.filename, "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), - "prompt": json.dumps(f""), "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, - } + self.read_user_metadata(item) + activation_text = item["user_metadata"].get("activation text") + preferred_weight = item["user_metadata"].get("preferred weight", 0.0) + item["prompt"] = json.dumps(f"") + + if activation_text: + item["prompt"] += " + " + json.dumps(" " + activation_text) + + yield item + def allowed_directories_for_previews(self): return [shared.cmd_opts.lora_dir] + def create_user_metadata_editor(self, ui, tabname): + return LoraUserMetadataEditor(ui, tabname, self) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 7007b353..8b67bf2b 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -113,7 +113,7 @@ function setupExtraNetworks() { onUiLoaded(setupExtraNetworks); -var re_extranet = /<([^:]+:[^:]+):[\d.]+>/; +var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/; var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g; function tryToRemoveExtraNetworkFromPrompt(textarea, text) { @@ -121,15 +121,22 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { var replaced = false; var newTextareaText; if (m) { + var extraTextAfterNet = m[2]; var partToSearch = m[1]; - newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found) { + var foundAtPosition = -1; + newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) { m = found.match(re_extranet); if (m[1] == partToSearch) { replaced = true; + foundAtPosition = pos; return ""; } return found; }); + + if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { + newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); + } } else { newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) { if (found == text) { @@ -288,3 +295,10 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { event.stopPropagation(); } + +function extraNetworksReloadAll() { + closePopup(); + + gradioApp().getElementById('txt2img_extra_refresh').click(); + gradioApp().getElementById('img2img_extra_refresh').click(); +} diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 8d20d026..0dbd7419 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -52,7 +52,7 @@ class UserMetadataEditor: def create_default_buttons(self): - with gr.Row(): + with gr.Row(elem_classes="edit-user-metadata-buttons"): self.button_cancel = gr.Button('Cancel') self.button_replace_preview = gr.Button('Replace preview', variant='primary') self.button_save = gr.Button('Save', variant='primary') @@ -88,8 +88,7 @@ class UserMetadataEditor: stats = os.stat(filename) params = [ ('File size: ', sysinfo.pretty_bytes(stats.st_size)), - ('Created: ', datetime.datetime.fromtimestamp(stats.st_ctime).strftime('%Y-%m-%d %H:%M')), - ('Last modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')), + ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')), ] return params @@ -100,7 +99,12 @@ class UserMetadataEditor: def put_values_into_components(self, name): user_metadata = self.get_user_metadata(name) - params = self.get_metadata_table(name) + try: + params = self.get_metadata_table(name) + except Exception as e: + errors.display(e, f"reading metadata info for {name}") + params = [] + table = '' + "".join(f"" for name, value in params) + '' return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name) @@ -128,7 +132,9 @@ class UserMetadataEditor: .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview])\ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) - self.button_save.click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[]).then(fn=None, _js="closePopup") + self.button_save\ + .click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[])\ + .then(fn=None, _js="extraNetworksReloadAll") def create_ui(self): with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box: @@ -142,7 +148,7 @@ class UserMetadataEditor: def save_preview(self, index, gallery, name): if len(gallery) == 0: print("There is no image in gallery to save as a preview.") - return [self.get_card_html(name)] + [page.create_html(self.ui.tabname) for page in self.ui.stored_extra_pages] + return [self.get_card_html(name)] + self.regenerate_ui_pages() item = self.page.items.get(name, {}) @@ -156,7 +162,10 @@ class UserMetadataEditor: images.save_image_with_geninfo(image, geninfo, item["local_preview"]) - return [self.get_card_html(name)] + [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] + return [self.get_card_html(name)] + self.regenerate_ui_pages() + + def regenerate_ui_pages(self): + return [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] def setup_ui(self, gallery): self.button_replace_preview.click( diff --git a/style.css b/style.css index 4431c1aa..af6344a8 100644 --- a/style.css +++ b/style.css @@ -1004,7 +1004,7 @@ footer { } div.block.gradio-box.edit-user-metadata { - min-width: 56em; + width: 56em; background: var(--body-background-fill); padding: 2em !important; } @@ -1021,3 +1021,10 @@ div.block.gradio-box.edit-user-metadata { .edit-user-metadata .wrap.translucent{ background: var(--body-background-fill); } +.edit-user-metadata .gradio-highlightedtext span{ + word-break: break-word; +} + +.edit-user-metadata-buttons{ + margin-top: 1.5em; +} From efceed8c7f99a959bfe1a4d9210f27aac91f7705 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 01:09:19 +0300 Subject: [PATCH 150/178] fix styles for dark people --- style.css | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/style.css b/style.css index af6344a8..4e22cfd6 100644 --- a/style.css +++ b/style.css @@ -1014,6 +1014,10 @@ div.block.gradio-box.edit-user-metadata { color: var(--body-text-color); } +.edit-user-metadata .file-metadata{ + color: var(--body-text-color); +} + .edit-user-metadata .file-metadata th{ text-align: left; } From 8c11b126e5bd5052154c095177390f249e8e9889 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sat, 15 Jul 2023 23:43:49 -0400 Subject: [PATCH 151/178] 404 when thumb file not found --- modules/ui_extra_networks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 693cafb6..a2565315 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -8,6 +8,7 @@ from modules.ui import up_down_symbol import gradio as gr import json import html +from fastapi.exceptions import HTTPException from modules.generation_parameters_copypaste import image_from_url_text @@ -26,6 +27,9 @@ def register_page(page): def fetch_file(filename: str = ""): from starlette.responses import FileResponse + if not os.path.isfile(filename): + raise HTTPException(status_code=404, detail="File not found") + if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs): raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") From a1d6ada69ac686a628e79b61b8f86d01592a7209 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 08:38:23 +0300 Subject: [PATCH 152/178] allow refreshing single card after editing user metadata instead of all cards --- .../Lora/ui_edit_user_metadata.py | 4 +- .../Lora/ui_extra_networks_lora.py | 54 ++++++++++--------- html/extra-networks-card.html | 2 +- javascript/extraNetworks.js | 17 ++++-- modules/ui_extra_networks.py | 20 +++++++ modules/ui_extra_networks_checkpoints.py | 31 ++++++----- modules/ui_extra_networks_hypernets.py | 31 ++++++----- .../ui_extra_networks_textual_inversion.py | 30 ++++++----- modules/ui_extra_networks_user_metadata.py | 38 ++++++++----- 9 files changed, 141 insertions(+), 86 deletions(-) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index c7dbd1c1..2aa65223 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -182,6 +182,4 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_notes, ] - self.button_save\ - .click(fn=self.save_lora_user_metadata, inputs=[self.edit_name_input, *edited_components], outputs=[]) \ - .then(fn=None, _js="extraNetworksReloadAll") + self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 95296275..80e741dc 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -13,31 +13,37 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): def refresh(self): lora.list_available_loras() + def create_item(self, name, index=None): + lora_on_disk = lora.available_loras.get(name) + + path, ext = os.path.splitext(lora_on_disk.filename) + + alias = lora_on_disk.get_alias() + + item = { + "name": name, + "filename": lora_on_disk.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(lora_on_disk.filename), + "local_preview": f"{path}.{shared.opts.samples_format}", + "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, + "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, + } + + self.read_user_metadata(item) + activation_text = item["user_metadata"].get("activation text") + preferred_weight = item["user_metadata"].get("preferred weight", 0.0) + item["prompt"] = json.dumps(f"") + + if activation_text: + item["prompt"] += " + " + json.dumps(" " + activation_text) + + return item + def list_items(self): - for index, (name, lora_on_disk) in enumerate(lora.available_loras.items()): - path, ext = os.path.splitext(lora_on_disk.filename) - - alias = lora_on_disk.get_alias() - - item = { - "name": name, - "filename": lora_on_disk.filename, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(lora_on_disk.filename), - "local_preview": f"{path}.{shared.opts.samples_format}", - "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, - "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, - } - - self.read_user_metadata(item) - activation_text = item["user_metadata"].get("activation text") - preferred_weight = item["user_metadata"].get("preferred weight", 0.0) - item["prompt"] = json.dumps(f"") - - if activation_text: - item["prompt"] += " + " + json.dumps(" " + activation_text) - + for index, name in enumerate(lora.available_loras): + item = self.create_item(name, index) yield item def allowed_directories_for_previews(self): diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index fb787ffe..eb8b1a67 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,4 +1,4 @@ -
+
{background_image}
{edit_button} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 8b67bf2b..e453094a 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -296,9 +296,18 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { event.stopPropagation(); } -function extraNetworksReloadAll() { - closePopup(); +function extraNetworksRefreshSingleCard(page, tabname, name) { + requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) { + if (data && data.html) { + var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function - gradioApp().getElementById('txt2img_extra_refresh').click(); - gradioApp().getElementById('img2img_extra_refresh').click(); + var newDiv = document.createElement('DIV'); + newDiv.innerHTML = data.html; + var newCard = newDiv.firstElementChild; + + newCard.style = ''; + card.parentElement.insertBefore(newCard, card); + card.parentElement.removeChild(card); + } + }); } diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index eaae6217..42c4d0ac 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -51,9 +51,26 @@ def get_metadata(page: str = "", item: str = ""): return JSONResponse({"metadata": metadata}) +def get_single_card(page: str = "", tabname: str = "", name: str = ""): + from starlette.responses import JSONResponse + + page = next(iter([x for x in extra_pages if x.name == page]), None) + + try: + item = page.create_item(name) + except Exception as e: + errors.display(e, "creating item for extra network") + item = page.items.get(name) + + item_html = page.create_html_for_item(item, tabname) + + return JSONResponse({"html": item_html}) + + def add_pages_to_demo(app): app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"]) + app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"]) class ExtraNetworksPage: @@ -168,6 +185,9 @@ class ExtraNetworksPage: return res + def create_item(self, name, index=None): + raise NotImplementedError() + def list_items(self): raise NotImplementedError() diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index bb5071e6..ef8cdf35 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -12,21 +12,24 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def refresh(self): shared.refresh_checkpoints() - def list_items(self): - checkpoint: sd_models.CheckpointInfo - for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()): - path, ext = os.path.splitext(checkpoint.filename) - yield { - "name": checkpoint.name_for_extra, - "filename": checkpoint.filename, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), - "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', - "local_preview": f"{path}.{shared.opts.samples_format}", - "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, + def create_item(self, name, index=None): + checkpoint: sd_models.CheckpointInfo = sd_models.checkpoints_list.get(name) + path, ext = os.path.splitext(checkpoint.filename) + return { + "name": checkpoint.name_for_extra, + "filename": checkpoint.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), + "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "local_preview": f"{path}.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, - } + } + + def list_items(self): + for index, name in enumerate(sd_models.checkpoints_list): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index ea0b7a44..8dae23c6 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -11,21 +11,24 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): def refresh(self): shared.reload_hypernetworks() + def create_item(self, name, index=None): + full_path = shared.hypernetworks[name] + path, ext = os.path.splitext(full_path) + + return { + "name": name, + "filename": full_path, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(path), + "prompt": json.dumps(f""), + "local_preview": f"{path}.preview.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, + } + def list_items(self): - for index, (name, full_path) in enumerate(shared.hypernetworks.items()): - path, ext = os.path.splitext(full_path) - - yield { - "name": name, - "filename": full_path, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(path), - "prompt": json.dumps(f""), - "local_preview": f"{path}.preview.{shared.opts.samples_format}", - "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, - - } + for index, name in enumerate(shared.hypernetworks): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return [shared.cmd_opts.hypernetwork_dir] diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 58a61c55..159f2d64 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -12,20 +12,24 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): def refresh(self): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) - def list_items(self): - for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()): - path, ext = os.path.splitext(embedding.filename) - yield { - "name": embedding.name, - "filename": embedding.filename, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(embedding.filename), - "prompt": json.dumps(embedding.name), - "local_preview": f"{path}.preview.{shared.opts.samples_format}", - "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, + def create_item(self, name, index=None): + embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) - } + path, ext = os.path.splitext(embedding.filename) + return { + "name": name, + "filename": embedding.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(embedding.filename), + "prompt": json.dumps(embedding.name), + "local_preview": f"{path}.preview.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, + } + + def list_items(self): + for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 0dbd7419..01ff4e4b 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -23,8 +23,10 @@ class UserMetadataEditor: self.edit_name = None self.edit_description = None + self.edit_notes = None self.html_filedata = None self.html_preview = None + self.html_status = None self.button_cancel = None self.button_replace_preview = None @@ -57,6 +59,8 @@ class UserMetadataEditor: self.button_replace_preview = gr.Button('Replace preview', variant='primary') self.button_save = gr.Button('Save', variant='primary') + self.html_status = gr.HTML(elem_classes="edit-user-metadata-status") + self.button_cancel.click(fn=None, _js="closePopup") def get_card_html(self, name): @@ -107,7 +111,7 @@ class UserMetadataEditor: table = '' + "".join(f"" for name, value in params) + '' - return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name) + return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', ''), def write_user_metadata(self, name, metadata): item = self.page.items.get(name, {}) @@ -117,24 +121,30 @@ class UserMetadataEditor: with open(basename + '.json', "w", encoding="utf8") as file: json.dump(metadata, file) - def save_user_metadata(self, name, desc): + def save_user_metadata(self, name, desc, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc + user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) + def setup_save_handler(self, button, func, components): + button\ + .click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\ + .then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[]) + def create_editor(self): self.create_default_editor_elems() + self.edit_notes = gr.TextArea(label='Notes', lines=4) + self.create_default_buttons() self.button_edit\ - .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview])\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) - self.button_save\ - .click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[])\ - .then(fn=None, _js="extraNetworksReloadAll") + self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes]) def create_ui(self): with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box: @@ -147,8 +157,7 @@ class UserMetadataEditor: def save_preview(self, index, gallery, name): if len(gallery) == 0: - print("There is no image in gallery to save as a preview.") - return [self.get_card_html(name)] + self.regenerate_ui_pages() + return self.get_card_html(name), "There is no image in gallery to save as a preview." item = self.page.items.get(name, {}) @@ -162,17 +171,20 @@ class UserMetadataEditor: images.save_image_with_geninfo(image, geninfo, item["local_preview"]) - return [self.get_card_html(name)] + self.regenerate_ui_pages() - - def regenerate_ui_pages(self): - return [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] + return self.get_card_html(name), '' def setup_ui(self, gallery): self.button_replace_preview.click( fn=self.save_preview, _js="function(x, y, z){return [selected_gallery_index(), y, z]}", inputs=[self.edit_name_input, gallery, self.edit_name_input], - outputs=[self.html_preview, *self.ui.pages] + outputs=[self.html_preview, self.html_status] + ).then( + fn=None, + _js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", + inputs=[self.edit_name_input], + outputs=[] ) + From 47d9dd0240872dc70fd26bc1bf309f49fe17c104 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 09:25:32 +0300 Subject: [PATCH 153/178] speedup extra networks listing --- extensions-builtin/Lora/lora.py | 12 ++++++--- .../Lora/ui_edit_user_metadata.py | 9 +++---- .../Lora/ui_extra_networks_lora.py | 9 ++++--- modules/cache.py | 27 ++++++++++--------- modules/ui_extra_networks.py | 20 +++++++++----- modules/ui_extra_networks_checkpoints.py | 4 +-- modules/ui_extra_networks_hypernets.py | 4 +-- .../ui_extra_networks_textual_inversion.py | 4 +-- 8 files changed, 51 insertions(+), 38 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index cd46e6c7..c8710922 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -3,7 +3,7 @@ import re import torch from typing import Union -from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes +from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} @@ -78,9 +78,16 @@ class LoraOnDisk: self.metadata = {} self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + def read_metadata(): + metadata = sd_models.read_metadata_from_safetensors(filename) + metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text + + return metadata + if self.is_safetensors: try: - self.metadata = sd_models.read_metadata_from_safetensors(filename) + #self.metadata = sd_models.read_metadata_from_safetensors(filename) + self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) except Exception as e: errors.display(e, f"reading lora {filename}") @@ -91,7 +98,6 @@ class LoraOnDisk: self.metadata = m - self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text self.alias = self.metadata.get('ss_output_name', self.name) self.hash = None diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 2aa65223..6db63b09 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -1,5 +1,4 @@ import html -import json import random import gradio as gr @@ -64,7 +63,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) def get_metadata_table(self, name): table = super().get_metadata_table(name) item = self.page.items.get(name, {}) - metadata = json.loads(item.get("metadata") or '{}') + metadata = item.get("metadata") or {} keys = [ ('ss_sd_model_name', "Model:"), @@ -91,7 +90,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) values = super().put_values_into_components(name) item = self.page.items.get(name, {}) - metadata = json.loads(item.get("metadata") or '{}') + metadata = item.get("metadata") or {} tags = build_tags(metadata) gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]] @@ -108,7 +107,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) def generate_random_prompt(self, name): item = self.page.items.get(name, {}) - metadata = json.loads(item.get("metadata") or '{}') + metadata = item.get("metadata") or {} tags = build_tags(metadata) return self.generate_random_prompt_from_tags(tags) @@ -142,7 +141,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_notes = gr.TextArea(label='Notes', lines=4) - generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt]) + generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False) def select_tag(activation_text, evt: gr.SelectData): tag = evt.value[0] diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 80e741dc..b2bc1810 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,8 +1,8 @@ -import json import os import lora from modules import shared, ui_extra_networks +from modules.ui_extra_networks import quote_js from ui_edit_user_metadata import LoraUserMetadataEditor @@ -20,6 +20,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): alias = lora_on_disk.get_alias() + # in 1.5 filename changes to be full filename instead of path without extension, and metadata is dict instead of json string item = { "name": name, "filename": lora_on_disk.filename, @@ -27,17 +28,17 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "description": self.find_description(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), "local_preview": f"{path}.{shared.opts.samples_format}", - "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, + "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, } self.read_user_metadata(item) activation_text = item["user_metadata"].get("activation text") preferred_weight = item["user_metadata"].get("preferred weight", 0.0) - item["prompt"] = json.dumps(f"") + item["prompt"] = quote_js(f"") if activation_text: - item["prompt"] += " + " + json.dumps(" " + activation_text) + item["prompt"] += " + " + quote_js(" " + activation_text) return item diff --git a/modules/cache.py b/modules/cache.py index 4c2db604..07180602 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -1,12 +1,12 @@ import json import os.path - -import filelock +import threading from modules.paths import data_path, script_path cache_filename = os.path.join(data_path, "cache.json") cache_data = None +cache_lock = threading.Lock() def dump_cache(): @@ -14,7 +14,7 @@ def dump_cache(): Saves all cache data to a file. """ - with filelock.FileLock(f"{cache_filename}.lock"): + with cache_lock: with open(cache_filename, "w", encoding="utf8") as file: json.dump(cache_data, file, indent=4) @@ -33,17 +33,18 @@ def cache(subsection): global cache_data if cache_data is None: - with filelock.FileLock(f"{cache_filename}.lock"): - if not os.path.isfile(cache_filename): - cache_data = {} - else: - try: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) - except Exception: - os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) - print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') + with cache_lock: + if cache_data is None: + if not os.path.isfile(cache_filename): cache_data = {} + else: + try: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') + cache_data = {} s = cache_data.get(subsection, {}) cache_data[subsection] = s diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 42c4d0ac..f9d1fa31 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -73,6 +73,12 @@ def add_pages_to_demo(app): app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"]) +def quote_js(s): + s = s.replace('\\', '\\\\') + s = s.replace('"', '\\"') + return f'"{s}"' + + class ExtraNetworksPage: def __init__(self, title): self.title = title @@ -203,7 +209,7 @@ class ExtraNetworksPage: onclick = item.get("onclick", None) if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' @@ -211,9 +217,9 @@ class ExtraNetworksPage: metadata_button = "" metadata = item.get("metadata") if metadata: - metadata_button = f"" + metadata_button = f"" - edit_button = f"
" + edit_button = f"
" local_path = "" filename = item.get("filename", "") @@ -239,12 +245,12 @@ class ExtraNetworksPage: "background_image": background_image, "style": f"'display: none; {height}{width}'", "prompt": item.get("prompt", None), - "tabname": json.dumps(tabname), - "local_preview": json.dumps(item["local_preview"]), + "tabname": quote_js(tabname), + "local_preview": quote_js(item["local_preview"]), "name": item["name"], "description": (item.get("description") or ""), "card_clicked": onclick, - "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), "metadata_button": metadata_button, "edit_button": edit_button, @@ -359,7 +365,7 @@ def create_ui(container, button, tabname): page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[]) + page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) editor = page.create_user_metadata_editor(ui, tabname) editor.create_ui() diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index ef8cdf35..e73b5b1f 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -1,8 +1,8 @@ import html -import json import os from modules import shared, ui_extra_networks, sd_models +from modules.ui_extra_networks import quote_js class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): @@ -21,7 +21,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), - "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 8dae23c6..e53ccb42 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -1,7 +1,7 @@ -import json import os from modules import shared, ui_extra_networks +from modules.ui_extra_networks import quote_js class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): @@ -21,7 +21,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(path), - "prompt": json.dumps(f""), + "prompt": quote_js(f""), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, } diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 159f2d64..d1794e50 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,7 +1,7 @@ -import json import os from modules import ui_extra_networks, sd_hijack, shared +from modules.ui_extra_networks import quote_js class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): @@ -22,7 +22,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(embedding.filename), - "prompt": json.dumps(embedding.name), + "prompt": quote_js(embedding.name), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, } From ccd97886da1f659472cdca3de8731f59a70bbc28 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 09:49:22 +0300 Subject: [PATCH 154/178] fix bogus metadata for extra networks appearing out of cache fix description editing for checkpoint not immediately appearing on cards --- modules/cache.py | 10 +++++----- modules/ui_extra_networks.py | 3 ++- modules/ui_extra_networks_checkpoints.py | 3 +-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/modules/cache.py b/modules/cache.py index 07180602..28d42a8c 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -80,18 +80,18 @@ def cached_data_for_file(subsection, title, filename, func): entry = existing_cache.get(title) if entry: - cached_mtime = existing_cache[title].get("mtime", 0) + cached_mtime = entry.get("mtime", 0) if ondisk_mtime > cached_mtime: entry = None if not entry: - entry = func() - if entry is None: + value = func() + if value is None: return None - entry['mtime'] = ondisk_mtime + entry = {'mtime': ondisk_mtime, 'value': value} existing_cache[title] = entry dump_cache() - return entry + return entry['value'] diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 760fba43..a4927c11 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -52,7 +52,7 @@ def get_metadata(page: str = "", item: str = ""): if metadata is None: return JSONResponse({}) - return JSONResponse({"metadata": metadata}) + return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)}) def get_single_card(page: str = "", tabname: str = "", name: str = ""): @@ -66,6 +66,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""): errors.display(e, "creating item for extra network") item = page.items.get(name) + page.read_user_metadata(item) item_html = page.create_html_for_item(item, tabname) return JSONResponse({"html": item_html}) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index e73b5b1f..76780cfd 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -13,7 +13,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): shared.refresh_checkpoints() def create_item(self, name, index=None): - checkpoint: sd_models.CheckpointInfo = sd_models.checkpoints_list.get(name) + checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) path, ext = os.path.splitext(checkpoint.filename) return { "name": checkpoint.name_for_extra, @@ -24,7 +24,6 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, - } def list_items(self): From 7b052eb70eb2a35ce4f776b1e2ab1389802a41b5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 10:07:02 +0300 Subject: [PATCH 155/178] add resolution calculation from buckets for lora user metadata page --- extensions-builtin/Lora/lora.py | 1 - .../Lora/ui_edit_user_metadata.py | 28 +++++++++++++++---- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index c8710922..467ad65f 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -86,7 +86,6 @@ class LoraOnDisk: if self.is_safetensors: try: - #self.metadata = sd_models.read_metadata_from_safetensors(filename) self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) except Exception as e: errors.display(e, f"reading lora {filename}") diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 6db63b09..354a1d68 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -65,17 +65,33 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) item = self.page.items.get(name, {}) metadata = item.get("metadata") or {} - keys = [ - ('ss_sd_model_name', "Model:"), - ('ss_resolution', "Resolution:"), - ('ss_clip_skip', "Clip skip:"), - ] + keys = { + 'ss_sd_model_name': "Model:", + 'ss_clip_skip': "Clip skip:", + } - for key, label in keys: + for key, label in keys.items(): value = metadata.get(key, None) if value is not None and str(value) != "None": table.append((label, html.escape(value))) + ss_bucket_info = metadata.get("ss_bucket_info") + if ss_bucket_info and "buckets" in ss_bucket_info: + resolutions = {} + for _, bucket in ss_bucket_info["buckets"].items(): + resolution = bucket["resolution"] + resolution = f'{resolution[1]}x{resolution[0]}' + + resolutions[resolution] = resolutions.get(resolution, 0) + int(bucket["count"]) + + resolutions_list = sorted(resolutions.keys(), key=resolutions.get, reverse=True) + resolutions_text = html.escape(", ".join(resolutions_list[0:4])) + if len(resolutions) > 4: + resolutions_text += ", ..." + resolutions_text = f"{resolutions_text}" + + table.append(('Resolutions:' if len(resolutions_list) > 1 else 'Resolution:', resolutions_text)) + image_count = 0 for _, params in metadata.get("ss_dataset_dirs", {}).items(): image_count += int(params.get("img_count", 0)) From 690d56f3c10e5359e15eeba9c68e56b2eb193ac3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 10:25:34 +0300 Subject: [PATCH 156/178] nuke thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs) --- html/image-update.svg | 7 --- javascript/hints.js | 1 - modules/shared.py | 5 +- modules/ui_extra_networks.py | 9 ++-- style.css | 94 ++++++------------------------------ 5 files changed, 22 insertions(+), 94 deletions(-) delete mode 100644 html/image-update.svg diff --git a/html/image-update.svg b/html/image-update.svg deleted file mode 100644 index 3abf12df..00000000 --- a/html/image-update.svg +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - - diff --git a/javascript/hints.js b/javascript/hints.js index 41201b2f..4167cb28 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -108,7 +108,6 @@ var titles = { "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.", "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.", - "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.", "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.", "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order listed.", "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction." diff --git a/modules/shared.py b/modules/shared.py index 427dcc50..f6604ef9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -466,10 +466,11 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), - "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), - "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), + "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), + "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a4927c11..d9deccb2 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -132,7 +132,6 @@ class ExtraNetworksPage: return "" def create_html(self, tabname): - view = shared.opts.extra_networks_default_view items_html = '' self.metadata = {} @@ -186,10 +185,10 @@ class ExtraNetworksPage: self_name_id = self.name.replace(" ", "_") res = f""" -
+
{subdirs_html}
-
+
{items_html}
""" @@ -248,12 +247,12 @@ class ExtraNetworksPage: args = { "background_image": background_image, - "style": f"'display: none; {height}{width}'", + "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", "prompt": item.get("prompt", None), "tabname": quote_js(tabname), "local_preview": quote_js(item["local_preview"]), "name": item["name"], - "description": (item.get("description") or ""), + "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), diff --git a/style.css b/style.css index 4e22cfd6..f7a8cb17 100644 --- a/style.css +++ b/style.css @@ -804,127 +804,67 @@ footer { width: auto; } -.extra-network-cards .nocards, .extra-network-thumbs .nocards{ +.extra-network-cards .nocards{ margin: 1.25em 0.5em 0.5em 0.5em; } -.extra-network-cards .nocards h1, .extra-network-thumbs .nocards h1{ +.extra-network-cards .nocards h1{ font-size: 1.5em; margin-bottom: 1em; } -.extra-network-cards .nocards li, .extra-network-thumbs .nocards li{ +.extra-network-cards .nocards li{ margin-left: 0.5em; } -.extra-network-cards .card .button-row, .extra-network-thumbs .card .button-row{ +.extra-network-cards .card .button-row{ display: none; position: absolute; color: white; right: 0; } -.extra-network-cards .card:hover .button-row, .extra-network-thumbs .card:hover .button-row{ +.extra-network-cards .card:hover .button-row{ display: flex; } -.extra-network-cards .card .card-button, .extra-network-thumbs .card .card-button{ +.extra-network-cards .card .card-button{ color: white; } -.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{ +.extra-network-cards .card .metadata-button:before{ content: "🛈"; } -.extra-network-cards .card .edit-button:before, .extra-network-thumbs .card .edit-button:before{ +.extra-network-cards .card .edit-button:before{ content: "🛠"; } .extra-network-cards .card .card-button { text-shadow: 2px 2px 3px black; padding: 0.25em; - font-size: 22pt; + font-size: 200%; width: 1.5em; } -.extra-network-thumbs .card .card-button { - text-shadow: 1px 1px 2px black; - padding: 0; - font-size: 16pt; - width: 1em; - top: -0.25em; -} -.extra-network-cards .card .card-button:hover, .extra-network-thumbs .card .card-button:hover{ +.extra-network-cards .card .card-button:hover{ color: red; } -.extra-network-thumbs { - display: flex; - flex-flow: row wrap; - gap: 10px; -} - -.extra-network-thumbs .card { - height: 6em; - width: 6em; - cursor: pointer; - background-image: url('./file=html/card-no-preview.png'); - background-size: cover; - background-position: center center; - position: relative; -} - -.extra-network-thumbs .card .preview, .standalone-card-preview.card .preview{ +.standalone-card-preview.card .preview{ position: absolute; object-fit: cover; width: 100%; height:100%; } -.extra-network-thumbs .card:hover .additional a { - display: inline-block; -} - -.extra-network-thumbs .actions .additional a { - background-image: url('./file=html/image-update.svg'); - background-repeat: no-repeat; - background-size: cover; - background-position: center center; - position: absolute; - top: 0; - left: 0; - width: 24px; - height: 24px; - display: none; - font-size: 0; - text-align: -9999; -} - -.extra-network-thumbs .actions .name { - position: absolute; - bottom: 0; - font-size: 10px; - padding: 3px; - width: 100%; - overflow: hidden; - white-space: nowrap; - text-overflow: ellipsis; - background: rgba(0,0,0,.5); - color: white; -} - -.extra-network-thumbs .card:hover .actions .name { - white-space: normal; - word-break: break-all; -} - .extra-network-cards .card, .standalone-card-preview.card{ display: inline-block; - margin: 0.5em; - width: 16em; - height: 24em; + margin: 0.5rem; + width: 16rem; + height: 24rem; box-shadow: 0 0 5px rgba(128, 128, 128, 0.5); - border-radius: 0.2em; + border-radius: 0.2rem; position: relative; background-size: auto 100%; @@ -958,10 +898,6 @@ footer { color: white; } -.extra-network-cards .card .actions:hover{ - box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important; -} - .extra-network-cards .card .actions .name{ font-size: 1.7em; font-weight: bold; From 9d3dd64fe9e95873347710ca1df1f1e88d1908e1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 10:44:04 +0300 Subject: [PATCH 157/178] minor restyling for extra networks --- modules/ui_extra_networks.py | 3 ++- style.css | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index d9deccb2..6c73998f 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -11,6 +11,7 @@ import html from fastapi.exceptions import HTTPException from modules.generation_parameters_copypaste import image_from_url_text +from modules.ui_components import ToolButton extra_pages = [] allowed_dirs = set() @@ -377,7 +378,7 @@ def create_ui(container, button, tabname): gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True) - gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder") + ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) diff --git a/style.css b/style.css index f7a8cb17..8a66c3d2 100644 --- a/style.css +++ b/style.css @@ -783,8 +783,7 @@ footer { margin: 0 0.15em; } .extra-networks .tab-nav .search, -.extra-networks .tab-nav .sort, -.extra-networks .tab-nav .sortorder{ +.extra-networks .tab-nav .sort{ display: inline-block; margin: 0.3em; align-self: center; From 570f42afd122405116b39b880cdb5163fd5ca3e2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 12:28:50 +0300 Subject: [PATCH 158/178] possible fix for FP16 VAE failing in img2img SDXL --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index e7b10808..6567b3cf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1303,7 +1303,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. - image = image.to(shared.device) + image = image.to(shared.device, dtype=devices.dtype_vae) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) From 67ea4eabc3e78c4b496a9fcd21aca95fd5ef7027 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 13:46:33 +0300 Subject: [PATCH 159/178] fix cache loading wrong entries from old cache files --- modules/cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/cache.py b/modules/cache.py index 28d42a8c..ddf44637 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -84,7 +84,7 @@ def cached_data_for_file(subsection, title, filename, func): if ondisk_mtime > cached_mtime: entry = None - if not entry: + if not entry or 'value' not in entry: value = func() if value is None: return None From 7d26c479eebec03c2abb28f7b5226791688a7cea Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 14:39:47 +0300 Subject: [PATCH 160/178] changelog for future 1.5.0 --- CHANGELOG.md | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 925403a9..30783d9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,63 @@ +## 1.5.0 + +### Features: + * SD XL support + * user metadata system for custom networks + * extended Lora metadata editor: set activation text, default weight, view tags, training info + * show github stars for extenstions + * img2img batch mode can read extra stuff from png info + * img2img batch works with subdirectories + * hotkeys to move prompt elements: alt+left/right + * restyle time taken/VRAM display + * add textual inversion hashes to infotext + * optimization: cache git extension repo information + +### Minor: + * checkbox to check/uncheck all extensions in the Installed tab + * add gradio user to infotext and to filename patterns + * allow gif for extra network previews + * add options to change colors in grid + * use natural sort for items in extra networks + * Mac: use empty_cache() from torch 2 to clear VRAM + * added automatic support for installing the right libraries for Navi3 (AMD) + * add option SWIN_torch_compile to accelerate SwinIR upscale + * suppress printing TI embedding info at start to console by default + * speedup extra networks listing + * added `[none]` filename token. + * removed thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs) + +### Extensions and API: + * api endpoints: /sdapi/v1/server-kill, /sdapi/v1/server-restart, /sdapi/v1/server-stop + * allow Script to have custom metaclass + * add model exists status check /sdapi/v1/options + * rename --add-stop-route to --api-server-stop + * add `before_hr` script callback + * add callback `after_extra_networks_activate` + * disable rich exception output in console for API by default, use WEBUI_RICH_EXCEPTIONS env var to enable + * return http 404 when thumb file not found + * allow replacing extensions index with environment variable + +### Bug Fixes: + * fix for catch errors when retrieving extension index #11290 + * fix very slow loading speed of .safetensors files when reading from network drives + * API cache cleanup + * fix UnicodeEncodeError when writing to file CLIP Interrogator batch mode + * fix warning of 'has_mps' deprecated from PyTorch + * fix problem with extra network saving images as previews losing generation info + * fix throwing exception when trying to resize image with I;16 mode + * fix for #11534: canvas zoom and pan extension hijacking shortcut keys + * fixed launch script to be runnable from any directory + * don't add "Seed Resize: -1x-1" to API image metadata + * correctly remove end parenthesis with ctrl+up/down + * fixing --subpath on newer gradio version + * fix: check fill size none zero when resize (fixes #11425) + * use submit and blur for quick settings textbox + * save img2img batch with images.save_image() + * + + + + ## 1.4.1 ### Bug Fixes: From b75b004fe62826455f1aa77e849e7da13902cb17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 23:13:55 +0300 Subject: [PATCH 161/178] lora extension rework to include other types of networks --- .../Lora/extra_networks_lora.py | 18 +- extensions-builtin/Lora/lora.py | 537 ------------------ extensions-builtin/Lora/lyco_helpers.py | 15 + extensions-builtin/Lora/network.py | 98 ++++ extensions-builtin/Lora/network_hada.py | 59 ++ extensions-builtin/Lora/network_lora.py | 70 +++ extensions-builtin/Lora/network_lyco.py | 39 ++ extensions-builtin/Lora/networks.py | 443 +++++++++++++++ .../Lora/scripts/lora_script.py | 79 +-- .../Lora/ui_extra_networks_lora.py | 8 +- 10 files changed, 777 insertions(+), 589 deletions(-) delete mode 100644 extensions-builtin/Lora/lora.py create mode 100644 extensions-builtin/Lora/lyco_helpers.py create mode 100644 extensions-builtin/Lora/network.py create mode 100644 extensions-builtin/Lora/network_hada.py create mode 100644 extensions-builtin/Lora/network_lora.py create mode 100644 extensions-builtin/Lora/network_lyco.py create mode 100644 extensions-builtin/Lora/networks.py diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 66ee9c85..8a6639cf 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -1,5 +1,5 @@ from modules import extra_networks, shared -import lora +import networks class ExtraNetworkLora(extra_networks.ExtraNetwork): @@ -9,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_lora - if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional): + if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional): p.all_prompts = [x + f"" for x in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) @@ -21,12 +21,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) - lora.load_loras(names, multipliers) + networks.load_networks(names, multipliers) if shared.opts.lora_add_hashes_to_infotext: - lora_hashes = [] - for item in lora.loaded_loras: - shorthash = item.lora_on_disk.shorthash + network_hashes = [] + for item in networks.loaded_networks: + shorthash = item.network_on_disk.shorthash if not shorthash: continue @@ -36,10 +36,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): alias = alias.replace(":", "").replace(",", "") - lora_hashes.append(f"{alias}: {shorthash}") + network_hashes.append(f"{alias}: {shorthash}") - if lora_hashes: - p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes) + if network_hashes: + p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes) def deactivate(self, p): pass diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py deleted file mode 100644 index 9cdff6ed..00000000 --- a/extensions-builtin/Lora/lora.py +++ /dev/null @@ -1,537 +0,0 @@ -import os -import re -import torch -from typing import Union - -from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache - -metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} - -re_digits = re.compile(r"\d+") -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") -re_compiled = {} - -suffix_conversion = { - "attentions": {}, - "resnets": { - "conv1": "in_layers_2", - "conv2": "out_layers_3", - "time_emb_proj": "emb_layers_1", - "conv_shortcut": "skip_connection", - } -} - - -def convert_diffusers_name_to_compvis(key, is_sd2): - def match(match_list, regex_text): - regex = re_compiled.get(regex_text) - if regex is None: - regex = re.compile(regex_text) - re_compiled[regex_text] = regex - - r = re.match(regex, key) - if not r: - return False - - match_list.clear() - match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) - return True - - m = [] - - if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) - return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - - if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) - return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" - - if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) - return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - - if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): - return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" - - if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): - return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" - - if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): - if is_sd2: - if 'mlp_fc1' in m[1]: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" - elif 'mlp_fc2' in m[1]: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" - else: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - - return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" - - if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): - if 'mlp_fc1' in m[1]: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" - elif 'mlp_fc2' in m[1]: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" - else: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - - return key - - -class LoraOnDisk: - def __init__(self, name, filename): - self.name = name - self.filename = filename - self.metadata = {} - self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" - - def read_metadata(): - metadata = sd_models.read_metadata_from_safetensors(filename) - metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text - - return metadata - - if self.is_safetensors: - try: - self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) - except Exception as e: - errors.display(e, f"reading lora {filename}") - - if self.metadata: - m = {} - for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): - m[k] = v - - self.metadata = m - - self.alias = self.metadata.get('ss_output_name', self.name) - - self.hash = None - self.shorthash = None - self.set_hash( - self.metadata.get('sshs_model_hash') or - hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or - '' - ) - - def set_hash(self, v): - self.hash = v - self.shorthash = self.hash[0:12] - - if self.shorthash: - available_lora_hash_lookup[self.shorthash] = self - - def read_hash(self): - if not self.hash: - self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') - - def get_alias(self): - if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases: - return self.name - else: - return self.alias - - -class LoraModule: - def __init__(self, name, lora_on_disk: LoraOnDisk): - self.name = name - self.lora_on_disk = lora_on_disk - self.multiplier = 1.0 - self.modules = {} - self.mtime = None - - self.mentioned_name = None - """the text that was used to add lora to prompt - can be either name or an alias""" - - -class LoraUpDownModule: - def __init__(self): - self.up = None - self.down = None - self.alpha = None - - -def assign_lora_names_to_compvis_modules(sd_model): - lora_layer_mapping = {} - - if shared.sd_model.is_sdxl: - for i, embedder in enumerate(shared.sd_model.conditioner.embedders): - if not hasattr(embedder, 'wrapped'): - continue - - for name, module in embedder.wrapped.named_modules(): - lora_name = f'{i}_{name.replace(".", "_")}' - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name - else: - for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): - lora_name = name.replace(".", "_") - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name - - for name, module in shared.sd_model.model.named_modules(): - lora_name = name.replace(".", "_") - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name - - sd_model.lora_layer_mapping = lora_layer_mapping - - -def load_lora(name, lora_on_disk): - lora = LoraModule(name, lora_on_disk) - lora.mtime = os.path.getmtime(lora_on_disk.filename) - - sd = sd_models.read_state_dict(lora_on_disk.filename) - - # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0 - if not hasattr(shared.sd_model, 'lora_layer_mapping'): - assign_lora_names_to_compvis_modules(shared.sd_model) - - keys_failed_to_match = {} - is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping - - for key_lora, weight in sd.items(): - key_lora_without_lora_parts, lora_key = key_lora.split(".", 1) - - key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2) - sd_module = shared.sd_model.lora_layer_mapping.get(key, None) - - if sd_module is None: - m = re_x_proj.match(key) - if m: - sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) - - # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" - if sd_module is None and "lora_unet" in key_lora_without_lora_parts: - key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model") - sd_module = shared.sd_model.lora_layer_mapping.get(key, None) - elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts: - key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model") - sd_module = shared.sd_model.lora_layer_mapping.get(key, None) - - if sd_module is None: - keys_failed_to_match[key_lora] = key - continue - - lora_module = lora.modules.get(key, None) - if lora_module is None: - lora_module = LoraUpDownModule() - lora.modules[key] = lora_module - - if lora_key == "alpha": - lora_module.alpha = weight.item() - continue - - if type(sd_module) == torch.nn.Linear: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(sd_module) == torch.nn.MultiheadAttention: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) - else: - print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}') - continue - raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}") - - with torch.no_grad(): - module.weight.copy_(weight) - - module.to(device=devices.cpu, dtype=devices.dtype) - - if lora_key == "lora_up.weight": - lora_module.up = module - elif lora_key == "lora_down.weight": - lora_module.down = module - else: - raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha") - - if keys_failed_to_match: - print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") - - return lora - - -def load_loras(names, multipliers=None): - already_loaded = {} - - for lora in loaded_loras: - if lora.name in names: - already_loaded[lora.name] = lora - - loaded_loras.clear() - - loras_on_disk = [available_lora_aliases.get(name, None) for name in names] - if any(x is None for x in loras_on_disk): - list_available_loras() - - loras_on_disk = [available_lora_aliases.get(name, None) for name in names] - - failed_to_load_loras = [] - - for i, name in enumerate(names): - lora = already_loaded.get(name, None) - - lora_on_disk = loras_on_disk[i] - - if lora_on_disk is not None: - if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: - try: - lora = load_lora(name, lora_on_disk) - except Exception as e: - errors.display(e, f"loading Lora {lora_on_disk.filename}") - continue - - lora.mentioned_name = name - - lora_on_disk.read_hash() - - if lora is None: - failed_to_load_loras.append(name) - print(f"Couldn't find Lora with name {name}") - continue - - lora.multiplier = multipliers[i] if multipliers else 1.0 - loaded_loras.append(lora) - - if failed_to_load_loras: - sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras)) - - -def lora_calc_updown(lora, module, target): - with torch.no_grad(): - up = module.up.weight.to(target.device, dtype=target.dtype) - down = module.down.weight.to(target.device, dtype=target.dtype) - - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): - updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) - else: - updown = up @ down - - updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - - return updown - - -def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): - weights_backup = getattr(self, "lora_weights_backup", None) - - if weights_backup is None: - return - - if isinstance(self, torch.nn.MultiheadAttention): - self.in_proj_weight.copy_(weights_backup[0]) - self.out_proj.weight.copy_(weights_backup[1]) - else: - self.weight.copy_(weights_backup) - - -def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): - """ - Applies the currently selected set of Loras to the weights of torch layer self. - If weights already have this particular set of loras applied, does nothing. - If not, restores orginal weights from backup and alters weights according to loras. - """ - - lora_layer_name = getattr(self, 'lora_layer_name', None) - if lora_layer_name is None: - return - - current_names = getattr(self, "lora_current_names", ()) - wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) - - weights_backup = getattr(self, "lora_weights_backup", None) - if weights_backup is None: - if isinstance(self, torch.nn.MultiheadAttention): - weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) - else: - weights_backup = self.weight.to(devices.cpu, copy=True) - - self.lora_weights_backup = weights_backup - - if current_names != wanted_names: - lora_restore_weights_from_backup(self) - - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is not None and hasattr(self, 'weight'): - self.weight += lora_calc_updown(lora, module, self.weight) - continue - - module_q = lora.modules.get(lora_layer_name + "_q_proj", None) - module_k = lora.modules.get(lora_layer_name + "_k_proj", None) - module_v = lora.modules.get(lora_layer_name + "_v_proj", None) - module_out = lora.modules.get(lora_layer_name + "_out_proj", None) - - if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: - updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight) - updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight) - updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight) - updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) - - self.in_proj_weight += updown_qkv - self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight) - continue - - if module is None: - continue - - print(f'failed to calculate lora weights for layer {lora_layer_name}') - - self.lora_current_names = wanted_names - - -def lora_forward(module, input, original_forward): - """ - Old way of applying Lora by executing operations during layer's forward. - Stacking many loras this way results in big performance degradation. - """ - - if len(loaded_loras) == 0: - return original_forward(module, input) - - input = devices.cond_cast_unet(input) - - lora_restore_weights_from_backup(module) - lora_reset_cached_weight(module) - - res = original_forward(module, input) - - lora_layer_name = getattr(module, 'lora_layer_name', None) - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is None: - continue - - module.up.to(device=devices.device) - module.down.to(device=devices.device) - - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - - return res - - -def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): - self.lora_current_names = () - self.lora_weights_backup = None - - -def lora_Linear_forward(self, input): - if shared.opts.lora_functional: - return lora_forward(self, input, torch.nn.Linear_forward_before_lora) - - lora_apply_weights(self) - - return torch.nn.Linear_forward_before_lora(self, input) - - -def lora_Linear_load_state_dict(self, *args, **kwargs): - lora_reset_cached_weight(self) - - return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) - - -def lora_Conv2d_forward(self, input): - if shared.opts.lora_functional: - return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora) - - lora_apply_weights(self) - - return torch.nn.Conv2d_forward_before_lora(self, input) - - -def lora_Conv2d_load_state_dict(self, *args, **kwargs): - lora_reset_cached_weight(self) - - return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) - - -def lora_MultiheadAttention_forward(self, *args, **kwargs): - lora_apply_weights(self) - - return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs) - - -def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs): - lora_reset_cached_weight(self) - - return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs) - - -def list_available_loras(): - available_loras.clear() - available_lora_aliases.clear() - forbidden_lora_aliases.clear() - available_lora_hash_lookup.clear() - forbidden_lora_aliases.update({"none": 1, "Addams": 1}) - - os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) - - candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) - for filename in candidates: - if os.path.isdir(filename): - continue - - name = os.path.splitext(os.path.basename(filename))[0] - try: - entry = LoraOnDisk(name, filename) - except OSError: # should catch FileNotFoundError and PermissionError etc. - errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True) - continue - - available_loras[name] = entry - - if entry.alias in available_lora_aliases: - forbidden_lora_aliases[entry.alias.lower()] = 1 - - available_lora_aliases[name] = entry - available_lora_aliases[entry.alias] = entry - - -re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") - - -def infotext_pasted(infotext, params): - if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: - return # if the other extension is active, it will handle those fields, no need to do anything - - added = [] - - for k in params: - if not k.startswith("AddNet Model "): - continue - - num = k[13:] - - if params.get("AddNet Module " + num) != "LoRA": - continue - - name = params.get("AddNet Model " + num) - if name is None: - continue - - m = re_lora_name.match(name) - if m: - name = m.group(1) - - multiplier = params.get("AddNet Weight A " + num, "1.0") - - added.append(f"") - - if added: - params["Prompt"] += "\n" + "".join(added) - - -available_loras = {} -available_lora_aliases = {} -available_lora_hash_lookup = {} -forbidden_lora_aliases = {} -loaded_loras = [] - -list_available_loras() diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py new file mode 100644 index 00000000..9ea499fb --- /dev/null +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -0,0 +1,15 @@ +import torch + + +def make_weight_cp(t, wa, wb): + temp = torch.einsum('i j k l, j r -> i r k l', t, wb) + return torch.einsum('i j k l, i r -> r j k l', temp, wa) + + +def rebuild_conventional(up, down, shape, dyn_dim=None): + up = up.reshape(up.size(0), -1) + down = down.reshape(down.size(0), -1) + if dyn_dim is not None: + up = up[:, :dyn_dim] + down = down[:dyn_dim, :] + return (up @ down).reshape(shape) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py new file mode 100644 index 00000000..a1fe6bbf --- /dev/null +++ b/extensions-builtin/Lora/network.py @@ -0,0 +1,98 @@ +import os +from collections import namedtuple + +import torch + +from modules import devices, sd_models, cache, errors, hashes, shared + +NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) + +metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} + + +class NetworkOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + self.metadata = {} + self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + + def read_metadata(): + metadata = sd_models.read_metadata_from_safetensors(filename) + metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text + + return metadata + + if self.is_safetensors: + try: + self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) + except Exception as e: + errors.display(e, f"reading lora {filename}") + + if self.metadata: + m = {} + for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): + m[k] = v + + self.metadata = m + + self.alias = self.metadata.get('ss_output_name', self.name) + + self.hash = None + self.shorthash = None + self.set_hash( + self.metadata.get('sshs_model_hash') or + hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or + '' + ) + + def set_hash(self, v): + self.hash = v + self.shorthash = self.hash[0:12] + + if self.shorthash: + import networks + networks.available_network_hash_lookup[self.shorthash] = self + + def read_hash(self): + if not self.hash: + self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') + + def get_alias(self): + import networks + if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases: + return self.name + else: + return self.alias + + +class Network: # LoraModule + def __init__(self, name, network_on_disk: NetworkOnDisk): + self.name = name + self.network_on_disk = network_on_disk + self.multiplier = 1.0 + self.modules = {} + self.mtime = None + + self.mentioned_name = None + """the text that was used to add the network to prompt - can be either name or an alias""" + + +class ModuleType: + def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: + return None + + +class NetworkModule: + def __init__(self, net: Network, weights: NetworkWeights): + self.network = net + self.network_key = weights.network_key + self.sd_key = weights.sd_key + self.sd_module = weights.sd_module + + def calc_updown(self, target): + raise NotImplementedError() + + def forward(self, x, y): + raise NotImplementedError() + diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py new file mode 100644 index 00000000..15e7ffd8 --- /dev/null +++ b/extensions-builtin/Lora/network_hada.py @@ -0,0 +1,59 @@ +import lyco_helpers +import network +import network_lyco + + +class ModuleTypeHada(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]): + return NetworkModuleHada(net, weights) + + return None + + +class NetworkModuleHada(network_lyco.NetworkModuleLyco): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.w1a = weights.w["hada_w1_a"] + self.w1b = weights.w["hada_w1_b"] + self.dim = self.w1b.shape[0] + self.w2a = weights.w["hada_w2_a"] + self.w2b = weights.w["hada_w2_b"] + + self.t1 = weights.w.get("hada_t1") + self.t2 = weights.w.get("hada_t2") + + self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None + self.scale = weights.w["scale"].item() if "scale" in weights.w else None + + def calc_updown(self, orig_weight): + w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [w1a.size(0), w1b.size(1)] + + if self.t1 is not None: + output_shape = [w1a.size(1), w1b.size(1)] + t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype) + updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) + output_shape += t1.shape[2:] + else: + if len(w1b.shape) == 4: + output_shape += w1b.shape[2:] + updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) + + if self.t2 is not None: + t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) + else: + updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) + + updown = updown1 * updown2 + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py new file mode 100644 index 00000000..b2d96537 --- /dev/null +++ b/extensions-builtin/Lora/network_lora.py @@ -0,0 +1,70 @@ +import torch + +import network +from modules import devices + + +class ModuleTypeLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]): + return NetworkModuleLora(net, weights) + + return None + + +class NetworkModuleLora(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.up = self.create_module(weights.w["lora_up.weight"]) + self.down = self.create_module(weights.w["lora_down.weight"]) + self.alpha = weights.w["alpha"] if "alpha" in weights.w else None + + def create_module(self, weight, none_ok=False): + if weight is None and none_ok: + return None + + if type(self.sd_module) == torch.nn.Linear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(self.sd_module) == torch.nn.MultiheadAttention: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) + else: + print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') + return None + + with torch.no_grad(): + module.weight.copy_(weight) + + module.to(device=devices.cpu, dtype=devices.dtype) + module.weight.requires_grad_(False) + + return module + + def calc_updown(self, target): + up = self.up.weight.to(target.device, dtype=target.dtype) + down = self.down.weight.to(target.device, dtype=target.dtype) + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + + return updown + + def forward(self, x, y): + self.up.to(device=devices.device) + self.down.to(device=devices.device) + + return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + + diff --git a/extensions-builtin/Lora/network_lyco.py b/extensions-builtin/Lora/network_lyco.py new file mode 100644 index 00000000..18a822fa --- /dev/null +++ b/extensions-builtin/Lora/network_lyco.py @@ -0,0 +1,39 @@ +import torch + +import lyco_helpers +import network +from modules import devices + + +class NetworkModuleLyco(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.dim = None + self.bias = weights.w.get("bias") + self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None + self.scale = weights.w["scale"].item() if "scale" in weights.w else None + + def finalize_updown(self, updown, orig_weight, output_shape): + if self.bias is not None: + updown = updown.reshape(self.bias.shape) + updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + + scale = ( + self.scale if self.scale is not None + else self.alpha / self.dim if self.dim is not None and self.alpha is not None + else 1.0 + ) + + return updown * scale * self.network.multiplier + diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py new file mode 100644 index 00000000..5b0ddfb6 --- /dev/null +++ b/extensions-builtin/Lora/networks.py @@ -0,0 +1,443 @@ +import os +import re + +import network +import network_lora +import network_hada + +import torch +from typing import Union + +from modules import shared, devices, sd_models, errors, scripts, sd_hijack + +module_types = [ + network_lora.ModuleTypeLora(), + network_hada.ModuleTypeHada(), +] + + +re_digits = re.compile(r"\d+") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} + + +def convert_diffusers_name_to_compvis(key, is_sd2): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" + + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): + if 'mlp_fc1' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return key + + +def assign_network_names_to_compvis_modules(sd_model): + network_layer_mapping = {} + + if shared.sd_model.is_sdxl: + for i, embedder in enumerate(shared.sd_model.conditioner.embedders): + if not hasattr(embedder, 'wrapped'): + continue + + for name, module in embedder.wrapped.named_modules(): + network_name = f'{i}_{name.replace(".", "_")}' + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + else: + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + network_name = name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + + for name, module in shared.sd_model.model.named_modules(): + network_name = name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + + sd_model.network_layer_mapping = network_layer_mapping + + +def load_network(name, network_on_disk): + net = network.Network(name, network_on_disk) + net.mtime = os.path.getmtime(network_on_disk.filename) + + sd = sd_models.read_state_dict(network_on_disk.filename) + + # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0 + if not hasattr(shared.sd_model, 'network_layer_mapping'): + assign_network_names_to_compvis_modules(shared.sd_model) + + keys_failed_to_match = {} + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping + + matched_networks = {} + + for key_network, weight in sd.items(): + key_network_without_network_parts, network_part = key_network.split(".", 1) + + key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + if sd_module is None: + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None) + + # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" + if sd_module is None and "lora_unet" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + if sd_module is None: + keys_failed_to_match[key_network] = key + continue + + if key not in matched_networks: + matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module) + + matched_networks[key].w[network_part] = weight + + for key, weights in matched_networks.items(): + net_module = None + for nettype in module_types: + net_module = nettype.create_module(net, weights) + if net_module is not None: + break + + if net_module is None: + raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") + + net.modules[key] = net_module + + if keys_failed_to_match: + print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}") + + return net + + +def load_networks(names, multipliers=None): + already_loaded = {} + + for net in loaded_networks: + if net.name in names: + already_loaded[net.name] = net + + loaded_networks.clear() + + networks_on_disk = [available_network_aliases.get(name, None) for name in names] + if any(x is None for x in networks_on_disk): + list_available_networks() + + networks_on_disk = [available_network_aliases.get(name, None) for name in names] + + failed_to_load_networks = [] + + for i, name in enumerate(names): + net = already_loaded.get(name, None) + + network_on_disk = networks_on_disk[i] + + if network_on_disk is not None: + if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime: + try: + net = load_network(name, network_on_disk) + except Exception as e: + errors.display(e, f"loading network {network_on_disk.filename}") + continue + + net.mentioned_name = name + + network_on_disk.read_hash() + + if net is None: + failed_to_load_networks.append(name) + print(f"Couldn't find network with name {name}") + continue + + net.multiplier = multipliers[i] if multipliers else 1.0 + loaded_networks.append(net) + + if failed_to_load_networks: + sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks)) + + +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + weights_backup = getattr(self, "network_weights_backup", None) + + if weights_backup is None: + return + + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + """ + Applies the currently selected set of networks to the weights of torch layer self. + If weights already have this particular set of networks applied, does nothing. + If not, restores orginal weights from backup and alters weights according to networks. + """ + + network_layer_name = getattr(self, 'network_layer_name', None) + if network_layer_name is None: + return + + current_names = getattr(self, "network_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks) + + weights_backup = getattr(self, "network_weights_backup", None) + if weights_backup is None: + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + + self.network_weights_backup = weights_backup + + if current_names != wanted_names: + network_restore_weights_from_backup(self) + + for net in loaded_networks: + module = net.modules.get(network_layer_name, None) + if module is not None and hasattr(self, 'weight'): + with torch.no_grad(): + updown = module.calc_updown(self.weight) + + if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + # inpainting model. zero pad updown to make channel[1] 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) + + self.weight += updown + + module_q = net.modules.get(network_layer_name + "_q_proj", None) + module_k = net.modules.get(network_layer_name + "_k_proj", None) + module_v = net.modules.get(network_layer_name + "_v_proj", None) + module_out = net.modules.get(network_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + with torch.no_grad(): + updown_q = module_q.calc_updown(self.in_proj_weight) + updown_k = module_k.calc_updown(self.in_proj_weight) + updown_v = module_v.calc_updown(self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += module_out.calc_updown(self.out_proj.weight) + continue + + if module is None: + continue + + print(f'failed to calculate network weights for layer {network_layer_name}') + + self.network_current_names = wanted_names + + +def network_forward(module, input, original_forward): + """ + Old way of applying Lora by executing operations during layer's forward. + Stacking many loras this way results in big performance degradation. + """ + + if len(loaded_networks) == 0: + return original_forward(module, input) + + input = devices.cond_cast_unet(input) + + network_restore_weights_from_backup(module) + network_reset_cached_weight(module) + + y = original_forward(module, input) + + network_layer_name = getattr(module, 'network_layer_name', None) + for lora in loaded_networks: + module = lora.modules.get(network_layer_name, None) + if module is None: + continue + + y = module.forward(y, input) + + return y + + +def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): + self.network_current_names = () + self.network_weights_backup = None + + +def network_Linear_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.Linear_forward_before_network) + + network_apply_weights(self) + + return torch.nn.Linear_forward_before_network(self, input) + + +def network_Linear_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs) + + +def network_Conv2d_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.Conv2d_forward_before_network) + + network_apply_weights(self) + + return torch.nn.Conv2d_forward_before_network(self, input) + + +def network_Conv2d_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs) + + +def network_MultiheadAttention_forward(self, *args, **kwargs): + network_apply_weights(self) + + return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs) + + +def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs) + + +def list_available_networks(): + available_networks.clear() + available_network_aliases.clear() + forbidden_network_aliases.clear() + available_network_hash_lookup.clear() + forbidden_network_aliases.update({"none": 1, "Addams": 1}) + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) + for filename in candidates: + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + try: + entry = network.NetworkOnDisk(name, filename) + except OSError: # should catch FileNotFoundError and PermissionError etc. + errors.report(f"Failed to load network {name} from {filename}", exc_info=True) + continue + + available_networks[name] = entry + + if entry.alias in available_network_aliases: + forbidden_network_aliases[entry.alias.lower()] = 1 + + available_network_aliases[name] = entry + available_network_aliases[entry.alias] = entry + + +re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") + + +def infotext_pasted(infotext, params): + if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: + return # if the other extension is active, it will handle those fields, no need to do anything + + added = [] + + for k in params: + if not k.startswith("AddNet Model "): + continue + + num = k[13:] + + if params.get("AddNet Module " + num) != "LoRA": + continue + + name = params.get("AddNet Model " + num) + if name is None: + continue + + m = re_network_name.match(name) + if m: + name = m.group(1) + + multiplier = params.get("AddNet Weight A " + num, "1.0") + + added.append(f"") + + if added: + params["Prompt"] += "\n" + "".join(added) + + +available_networks = {} +available_network_aliases = {} +loaded_networks = [] +available_network_hash_lookup = {} +forbidden_network_aliases = {} + +list_available_networks() diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index e650f469..81e6572a 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -4,18 +4,19 @@ import torch import gradio as gr from fastapi import FastAPI -import lora +import network +import networks import extra_networks_lora import ui_extra_networks_lora from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): - torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora - torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora - torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora - torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora - torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora - torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora + torch.nn.Linear.forward = torch.nn.Linear_forward_before_network + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network + torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network + torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network + torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network + torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network def before_ui(): @@ -23,50 +24,50 @@ def before_ui(): extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) -if not hasattr(torch.nn, 'Linear_forward_before_lora'): - torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward +if not hasattr(torch.nn, 'Linear_forward_before_network'): + torch.nn.Linear_forward_before_network = torch.nn.Linear.forward -if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): - torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict +if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'): + torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict -if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): - torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward +if not hasattr(torch.nn, 'Conv2d_forward_before_network'): + torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward -if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): - torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict +if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'): + torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict -if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'): - torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward +if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'): + torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward -if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'): - torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict +if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'): + torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict -torch.nn.Linear.forward = lora.lora_Linear_forward -torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict -torch.nn.Conv2d.forward = lora.lora_Conv2d_forward -torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict -torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward -torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict +torch.nn.Linear.forward = networks.network_Linear_forward +torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict +torch.nn.Conv2d.forward = networks.network_Conv2d_forward +torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict +torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward +torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict -script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) +script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) script_callbacks.on_before_ui(before_ui) -script_callbacks.on_infotext_pasted(lora.infotext_pasted) +script_callbacks.on_infotext_pasted(networks.infotext_pasted) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { - "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras), + "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks), "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), })) shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), { - "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), + "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), })) -def create_lora_json(obj: lora.LoraOnDisk): +def create_lora_json(obj: network.NetworkOnDisk): return { "name": obj.name, "alias": obj.alias, @@ -75,17 +76,17 @@ def create_lora_json(obj: lora.LoraOnDisk): } -def api_loras(_: gr.Blocks, app: FastAPI): +def api_networks(_: gr.Blocks, app: FastAPI): @app.get("/sdapi/v1/loras") async def get_loras(): - return [create_lora_json(obj) for obj in lora.available_loras.values()] + return [create_lora_json(obj) for obj in networks.available_networks.values()] @app.post("/sdapi/v1/refresh-loras") async def refresh_loras(): - return lora.list_available_loras() + return networks.list_available_networks() -script_callbacks.on_app_started(api_loras) +script_callbacks.on_app_started(api_networks) re_lora = re.compile(" Date: Sun, 16 Jul 2023 23:14:57 +0300 Subject: [PATCH 162/178] linter --- extensions-builtin/Lora/network.py | 4 +--- extensions-builtin/Lora/network_lyco.py | 4 ---- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index a1fe6bbf..4ac63722 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -1,9 +1,7 @@ import os from collections import namedtuple -import torch - -from modules import devices, sd_models, cache, errors, hashes, shared +from modules import sd_models, cache, errors, hashes, shared NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) diff --git a/extensions-builtin/Lora/network_lyco.py b/extensions-builtin/Lora/network_lyco.py index 18a822fa..fc135314 100644 --- a/extensions-builtin/Lora/network_lyco.py +++ b/extensions-builtin/Lora/network_lyco.py @@ -1,8 +1,4 @@ -import torch - -import lyco_helpers import network -from modules import devices class NetworkModuleLyco(network.NetworkModule): From ef5dac7786916dd39711edb2b8e90ce96ef78fca Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 00:01:17 +0300 Subject: [PATCH 163/178] fix --- extensions-builtin/Lora/network_hada.py | 3 --- extensions-builtin/Lora/networks.py | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 15e7ffd8..799bb3bc 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -27,9 +27,6 @@ class NetworkModuleHada(network_lyco.NetworkModuleLyco): self.t1 = weights.w.get("hada_t1") self.t2 = weights.w.get("hada_t2") - self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None - self.scale = weights.w["scale"].item() if "scale" in weights.w else None - def calc_updown(self, orig_weight): w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 5b0ddfb6..90374faa 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -271,6 +271,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) self.weight += updown + continue module_q = net.modules.get(network_layer_name + "_q_proj", None) module_k = net.modules.get(network_layer_name + "_k_proj", None) From 58c3df32f3a73b20ea33d1709a1d25818b8a98dd Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 00:12:18 +0300 Subject: [PATCH 164/178] IA3 support --- extensions-builtin/Lora/network_ia3.py | 32 ++++++++++++++++++++++++++ extensions-builtin/Lora/networks.py | 2 ++ 2 files changed, 34 insertions(+) create mode 100644 extensions-builtin/Lora/network_ia3.py diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py new file mode 100644 index 00000000..99f2307c --- /dev/null +++ b/extensions-builtin/Lora/network_ia3.py @@ -0,0 +1,32 @@ +import lyco_helpers +import network +import network_lyco + + +class ModuleTypeIa3(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["weight"]): + return NetworkModuleIa3(net, weights) + + return None + + +class NetworkModuleIa3(network_lyco.NetworkModuleLyco): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w = weights.w["weight"] + self.on_input = weights.w["on_input"].item() + + def calc_updown(self, orig_weight): + w = self.w.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [w.size(0), orig_weight.size(1)] + if self.on_input: + output_shape.reverse() + else: + w = w.reshape(-1, 1) + + updown = orig_weight * w + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 90374faa..bf810b5b 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -4,6 +4,7 @@ import re import network import network_lora import network_hada +import network_ia3 import torch from typing import Union @@ -13,6 +14,7 @@ from modules import shared, devices, sd_models, errors, scripts, sd_hijack module_types = [ network_lora.ModuleTypeLora(), network_hada.ModuleTypeHada(), + network_ia3.ModuleTypeIa3(), ] From 46466f09d0b0c14118033dee6af0f876059776d3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 00:29:07 +0300 Subject: [PATCH 165/178] Lokr support --- extensions-builtin/Lora/network_ia3.py | 1 - extensions-builtin/Lora/network_lokr.py | 65 +++++++++++++++++++++++++ extensions-builtin/Lora/networks.py | 2 + 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 extensions-builtin/Lora/network_lokr.py diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index 99f2307c..d8806da0 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -1,4 +1,3 @@ -import lyco_helpers import network import network_lyco diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py new file mode 100644 index 00000000..f1731924 --- /dev/null +++ b/extensions-builtin/Lora/network_lokr.py @@ -0,0 +1,65 @@ +import torch + +import lyco_helpers +import network +import network_lyco + + +class ModuleTypeLokr(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + has_1 = "lokr_w1" in weights.w or ("lokr_w1a" in weights.w and "lokr_w1b" in weights.w) + has_2 = "lokr_w2" in weights.w or ("lokr_w2a" in weights.w and "lokr_w2b" in weights.w) + if has_1 and has_2: + return NetworkModuleLokr(net, weights) + + return None + + +def make_kron(orig_shape, w1, w2): + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + return torch.kron(w1, w2).reshape(orig_shape) + + +class NetworkModuleLokr(network_lyco.NetworkModuleLyco): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w1 = weights.w.get("lokr_w1") + self.w1a = weights.w.get("lokr_w1_a") + self.w1b = weights.w.get("lokr_w1_b") + self.dim = self.w1b.shape[0] if self.w1b else self.dim + self.w2 = weights.w.get("lokr_w2") + self.w2a = weights.w.get("lokr_w2_a") + self.w2b = weights.w.get("lokr_w2_b") + self.dim = self.w2b.shape[0] if self.w2b else self.dim + self.t2 = weights.w.get("lokr_t2") + + def calc_updown(self, orig_weight): + if self.w1 is not None: + w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype) + else: + w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = w1a @ w1b + + if self.w2 is not None: + w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype) + elif self.t2 is None: + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = w2a @ w2b + else: + t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) + + output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] + if len(orig_weight.shape) == 4: + output_shape = orig_weight.shape + + updown = make_kron(output_shape, w1, w2) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index bf810b5b..1b358561 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -5,6 +5,7 @@ import network import network_lora import network_hada import network_ia3 +import network_lokr import torch from typing import Union @@ -15,6 +16,7 @@ module_types = [ network_lora.ModuleTypeLora(), network_hada.ModuleTypeHada(), network_ia3.ModuleTypeIa3(), + network_lokr.ModuleTypeLokr(), ] From 7870937c770aaba9e681c299f923ba645163c85c Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 17 Jul 2023 12:25:29 +0900 Subject: [PATCH 166/178] XYZ always_discard_next_to_last_sigma Co-authored-by: Franck Mahon --- scripts/xyz_grid.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 7821cc65..ee30747c 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -144,11 +144,18 @@ def apply_face_restore(p, opt, x): p.restore_faces = is_active -def apply_override(field): +def apply_override(field, boolean: bool = False): def fun(p, x, xs): + if boolean: + x = True if x == "True" else False p.override_settings[field] = x return fun + +def boolean_choice(): + return ["True", "False"] + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -235,6 +242,7 @@ axis_options = [ AxisOption("Face restore", str, apply_face_restore, format_value=format_value), AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')), AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')), + AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice), ] From c03856bfdf30fd0e061caefd60231eb86a983c71 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 17 Jul 2023 12:45:10 +0900 Subject: [PATCH 167/178] reversible boolean_choice order --- scripts/xyz_grid.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index ee30747c..bddc28c7 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -152,8 +152,10 @@ def apply_override(field, boolean: bool = False): return fun -def boolean_choice(): - return ["True", "False"] +def boolean_choice(reverse: bool = False): + def choice(): + return ["False", "True"] if reverse else ["True", "False"] + return choice def format_value_add_label(p, opt, x): @@ -242,7 +244,7 @@ axis_options = [ AxisOption("Face restore", str, apply_face_restore, format_value=format_value), AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')), AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')), - AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice), + AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)), ] From 8941297ceb3e71fa16fd842b135786b0ebc1b2b1 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 17 Jul 2023 12:45:38 +0900 Subject: [PATCH 168/178] lowercase --- scripts/xyz_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index bddc28c7..1010845e 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -147,7 +147,7 @@ def apply_face_restore(p, opt, x): def apply_override(field, boolean: bool = False): def fun(p, x, xs): if boolean: - x = True if x == "True" else False + x = True if x.lower() == "true" else False p.override_settings[field] = x return fun From 238adeaffb037dedbcefe41e7fd4814a1f17baa2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 09:00:47 +0300 Subject: [PATCH 169/178] support specifying te and unet weights separately update lora code support full module --- .../Lora/extra_networks_lora.py | 22 ++++-- extensions-builtin/Lora/lyco_helpers.py | 6 ++ extensions-builtin/Lora/network.py | 40 ++++++++++- extensions-builtin/Lora/network_full.py | 23 ++++++ extensions-builtin/Lora/network_hada.py | 3 +- extensions-builtin/Lora/network_ia3.py | 3 +- extensions-builtin/Lora/network_lokr.py | 3 +- extensions-builtin/Lora/network_lora.py | 72 +++++++++++-------- extensions-builtin/Lora/network_lyco.py | 35 --------- extensions-builtin/Lora/networks.py | 22 ++++-- 10 files changed, 151 insertions(+), 78 deletions(-) create mode 100644 extensions-builtin/Lora/network_full.py delete mode 100644 extensions-builtin/Lora/network_lyco.py diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 8a6639cf..084c41d0 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -14,14 +14,28 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) names = [] - multipliers = [] + te_multipliers = [] + unet_multipliers = [] + dyn_dims = [] for params in params_list: assert params.items - names.append(params.items[0]) - multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + names.append(params.positional[0]) - networks.load_networks(names, multipliers) + te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0 + te_multiplier = float(params.named.get("te", te_multiplier)) + + unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else 1.0 + unet_multiplier = float(params.named.get("unet", unet_multiplier)) + + dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None + dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim + + te_multipliers.append(te_multiplier) + unet_multipliers.append(unet_multiplier) + dyn_dims.append(dyn_dim) + + networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims) if shared.opts.lora_add_hashes_to_infotext: network_hashes = [] diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py index 9ea499fb..279b34bc 100644 --- a/extensions-builtin/Lora/lyco_helpers.py +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -13,3 +13,9 @@ def rebuild_conventional(up, down, shape, dyn_dim=None): up = up[:, :dyn_dim] down = down[:dyn_dim, :] return (up @ down).reshape(shape) + + +def rebuild_cp_decomposition(up, down, mid): + up = up.reshape(up.size(0), -1) + down = down.reshape(down.size(0), -1) + return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 4ac63722..fe42dbdd 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -68,7 +68,9 @@ class Network: # LoraModule def __init__(self, name, network_on_disk: NetworkOnDisk): self.name = name self.network_on_disk = network_on_disk - self.multiplier = 1.0 + self.te_multiplier = 1.0 + self.unet_multiplier = 1.0 + self.dyn_dim = None self.modules = {} self.mtime = None @@ -88,6 +90,42 @@ class NetworkModule: self.sd_key = weights.sd_key self.sd_module = weights.sd_module + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.dim = None + self.bias = weights.w.get("bias") + self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None + self.scale = weights.w["scale"].item() if "scale" in weights.w else None + + def multiplier(self): + if 'transformer' in self.sd_key[:20]: + return self.network.te_multiplier + else: + return self.network.unet_multiplier + + def calc_scale(self): + if self.scale is not None: + return self.scale + if self.dim is not None and self.alpha is not None: + return self.alpha / self.dim + + return 1.0 + + def finalize_updown(self, updown, orig_weight, output_shape): + if self.bias is not None: + updown = updown.reshape(self.bias.shape) + updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + + return updown * self.calc_scale() * self.multiplier() + def calc_updown(self, target): raise NotImplementedError() diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py new file mode 100644 index 00000000..f0d8a6e0 --- /dev/null +++ b/extensions-builtin/Lora/network_full.py @@ -0,0 +1,23 @@ +import lyco_helpers +import network + + +class ModuleTypeFull(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["diff"]): + return NetworkModuleFull(net, weights) + + return None + + +class NetworkModuleFull(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.weight = weights.w.get("diff") + + def calc_updown(self, orig_weight): + output_shape = self.weight.shape + updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 799bb3bc..5fcb0695 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -1,6 +1,5 @@ import lyco_helpers import network -import network_lyco class ModuleTypeHada(network.ModuleType): @@ -11,7 +10,7 @@ class ModuleTypeHada(network.ModuleType): return None -class NetworkModuleHada(network_lyco.NetworkModuleLyco): +class NetworkModuleHada(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index d8806da0..7edc4249 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -1,5 +1,4 @@ import network -import network_lyco class ModuleTypeIa3(network.ModuleType): @@ -10,7 +9,7 @@ class ModuleTypeIa3(network.ModuleType): return None -class NetworkModuleIa3(network_lyco.NetworkModuleLyco): +class NetworkModuleIa3(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index f1731924..920062e2 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -2,7 +2,6 @@ import torch import lyco_helpers import network -import network_lyco class ModuleTypeLokr(network.ModuleType): @@ -22,7 +21,7 @@ def make_kron(orig_shape, w1, w2): return torch.kron(w1, w2).reshape(orig_shape) -class NetworkModuleLokr(network_lyco.NetworkModuleLyco): +class NetworkModuleLokr(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index b2d96537..26c0a72c 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -1,5 +1,6 @@ import torch +import lyco_helpers import network from modules import devices @@ -16,29 +17,42 @@ class NetworkModuleLora(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) - self.up = self.create_module(weights.w["lora_up.weight"]) - self.down = self.create_module(weights.w["lora_down.weight"]) - self.alpha = weights.w["alpha"] if "alpha" in weights.w else None + self.up_model = self.create_module(weights.w, "lora_up.weight") + self.down_model = self.create_module(weights.w, "lora_down.weight") + self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True) + + self.dim = weights.w["lora_down.weight"].shape[0] + + def create_module(self, weights, key, none_ok=False): + weight = weights.get(key) - def create_module(self, weight, none_ok=False): if weight is None and none_ok: return None - if type(self.sd_module) == torch.nn.Linear: + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + + if is_linear: + weight = weight.reshape(weight.shape[0], -1) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.MultiheadAttention: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): + elif is_conv and key == "lora_down.weight" or key == "dyn_up": + if len(weight.shape) == 2: + weight = weight.reshape(weight.shape[0], -1, 1, 1) + + if weight.shape[2] != 1 or weight.shape[3] != 1: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + else: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif is_conv and key == "lora_mid.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + elif is_conv and key == "lora_up.weight" or key == "dyn_down": module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) else: - print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') - return None + raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') with torch.no_grad(): + if weight.shape != module.weight.shape: + weight = weight.reshape(module.weight.shape) module.weight.copy_(weight) module.to(device=devices.cpu, dtype=devices.dtype) @@ -46,25 +60,27 @@ class NetworkModuleLora(network.NetworkModule): return module - def calc_updown(self, target): - up = self.up.weight.to(target.device, dtype=target.dtype) - down = self.down.weight.to(target.device, dtype=target.dtype) + def calc_updown(self, orig_weight): + up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): - updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + output_shape = [up.size(0), down.size(1)] + if self.mid_model is not None: + # cp-decomposition + mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) + output_shape += mid.shape[2:] else: - updown = up @ down + if len(down.shape) == 4: + output_shape += down.shape[2:] + updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim) - updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) - - return updown + return self.finalize_updown(updown, orig_weight, output_shape) def forward(self, x, y): - self.up.to(device=devices.device) - self.down.to(device=devices.device) + self.up_model.to(device=devices.device) + self.down_model.to(device=devices.device) - return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale() diff --git a/extensions-builtin/Lora/network_lyco.py b/extensions-builtin/Lora/network_lyco.py deleted file mode 100644 index fc135314..00000000 --- a/extensions-builtin/Lora/network_lyco.py +++ /dev/null @@ -1,35 +0,0 @@ -import network - - -class NetworkModuleLyco(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - super().__init__(net, weights) - - if hasattr(self.sd_module, 'weight'): - self.shape = self.sd_module.weight.shape - - self.dim = None - self.bias = weights.w.get("bias") - self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None - self.scale = weights.w["scale"].item() if "scale" in weights.w else None - - def finalize_updown(self, updown, orig_weight, output_shape): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - scale = ( - self.scale if self.scale is not None - else self.alpha / self.dim if self.dim is not None and self.alpha is not None - else 1.0 - ) - - return updown * scale * self.network.multiplier - diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 1b358561..401430e8 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -6,6 +6,7 @@ import network_lora import network_hada import network_ia3 import network_lokr +import network_full import torch from typing import Union @@ -17,6 +18,7 @@ module_types = [ network_hada.ModuleTypeHada(), network_ia3.ModuleTypeIa3(), network_lokr.ModuleTypeLokr(), + network_full.ModuleTypeFull(), ] @@ -52,6 +54,15 @@ def convert_diffusers_name_to_compvis(key, is_sd2): m = [] + if match(m, r"lora_unet_conv_in(.*)"): + return f'diffusion_model_input_blocks_0_0{m[0]}' + + if match(m, r"lora_unet_conv_out(.*)"): + return f'diffusion_model_out_2{m[0]}' + + if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"): + return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}" + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" @@ -179,7 +190,7 @@ def load_network(name, network_on_disk): return net -def load_networks(names, multipliers=None): +def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): already_loaded = {} for net in loaded_networks: @@ -218,7 +229,9 @@ def load_networks(names, multipliers=None): print(f"Couldn't find network with name {name}") continue - net.multiplier = multipliers[i] if multipliers else 1.0 + net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0 + net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0 + net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 loaded_networks.append(net) if failed_to_load_networks: @@ -250,7 +263,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn return current_names = getattr(self, "network_current_names", ()) - wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks) + wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) weights_backup = getattr(self, "network_weights_backup", None) if weights_backup is None: @@ -288,9 +301,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn updown_k = module_k.calc_updown(self.in_proj_weight) updown_v = module_v.calc_updown(self.in_proj_weight) updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + updown_out = module_out.calc_updown(self.out_proj.weight) self.in_proj_weight += updown_qkv - self.out_proj.weight += module_out.calc_updown(self.out_proj.weight) + self.out_proj.weight += updown_out continue if module is None: From 2e07a8ae6b1d92838b3a8a0f6eaf5fcf4a92d48f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 09:05:18 +0300 Subject: [PATCH 170/178] some backwards compatibility linter --- extensions-builtin/Lora/lora.py | 9 +++++++++ extensions-builtin/Lora/network_full.py | 1 - extensions-builtin/Lora/scripts/lora_script.py | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 extensions-builtin/Lora/lora.py diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py new file mode 100644 index 00000000..9365aa74 --- /dev/null +++ b/extensions-builtin/Lora/lora.py @@ -0,0 +1,9 @@ +import networks + +list_available_loras = networks.list_available_networks + +available_loras = networks.available_networks +available_lora_aliases = networks.available_network_aliases +available_lora_hash_lookup = networks.available_network_hash_lookup +forbidden_lora_aliases = networks.forbidden_network_aliases +loaded_loras = networks.loaded_networks diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index f0d8a6e0..109b4c2c 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -1,4 +1,3 @@ -import lyco_helpers import network diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 81e6572a..4c75821e 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -6,6 +6,7 @@ from fastapi import FastAPI import network import networks +import lora # noqa:F401 import extra_networks_lora import ui_extra_networks_lora from modules import script_callbacks, ui_extra_networks, extra_networks, shared From 9251ae3bc78e465058c286e86f3c26cb6f819a31 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 09:29:36 +0300 Subject: [PATCH 171/178] delay writing cache to prevent writing the same thing over and over --- modules/cache.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/modules/cache.py b/modules/cache.py index ddf44637..71fe6302 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -1,6 +1,7 @@ import json import os.path import threading +import time from modules.paths import data_path, script_path @@ -8,15 +9,37 @@ cache_filename = os.path.join(data_path, "cache.json") cache_data = None cache_lock = threading.Lock() +dump_cache_after = None +dump_cache_thread = None + def dump_cache(): """ - Saves all cache data to a file. + Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written. """ + global dump_cache_after + global dump_cache_thread + + def thread_func(): + global dump_cache_after + global dump_cache_thread + + while dump_cache_after is not None and time.time() < dump_cache_after: + time.sleep(1) + + with cache_lock: + with open(cache_filename, "w", encoding="utf8") as file: + json.dump(cache_data, file, indent=4) + + dump_cache_after = None + dump_cache_thread = None + with cache_lock: - with open(cache_filename, "w", encoding="utf8") as file: - json.dump(cache_data, file, indent=4) + dump_cache_after = time.time() + 5 + if dump_cache_thread is None: + dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func) + dump_cache_thread.start() def cache(subsection): From 35510f7529dc05437a82496187ef06b852be9ab1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 10:06:02 +0300 Subject: [PATCH 172/178] add alias to lyco network read networks from LyCORIS dir if it exists add credits --- README.md | 1 + extensions-builtin/Lora/networks.py | 3 ++- extensions-builtin/Lora/scripts/lora_script.py | 5 ++++- modules/extra_networks.py | 16 ++++++++++++++-- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e6d8e4bd..b796d150 100644 --- a/README.md +++ b/README.md @@ -168,5 +168,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Security advice - RyotaK - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd +- LyCORIS - KohakuBlueleaf - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 401430e8..7b4c0312 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -11,7 +11,7 @@ import network_full import torch from typing import Union -from modules import shared, devices, sd_models, errors, scripts, sd_hijack +from modules import shared, devices, sd_models, errors, scripts, sd_hijack, paths module_types = [ network_lora.ModuleTypeLora(), @@ -399,6 +399,7 @@ def list_available_networks(): os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) + candidates += list(shared.walk_files(os.path.join(paths.models_path, "LyCORIS"), allowed_extensions=[".pt", ".ckpt", ".safetensors"])) for filename in candidates: if os.path.isdir(filename): continue diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 4c75821e..f478f718 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -22,7 +22,10 @@ def unload(): def before_ui(): ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) - extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) + + extra_network = extra_networks_lora.ExtraNetworkLora() + extra_networks.register_extra_network(extra_network) + extra_networks.register_extra_network_alias(extra_network, "lyco") if not hasattr(torch.nn, 'Linear_forward_before_network'): diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 41799b0a..6ae07e91 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -4,16 +4,22 @@ from collections import defaultdict from modules import errors extra_network_registry = {} +extra_network_aliases = {} def initialize(): extra_network_registry.clear() + extra_network_aliases.clear() def register_extra_network(extra_network): extra_network_registry[extra_network.name] = extra_network +def register_extra_network_alias(extra_network, alias): + extra_network_aliases[alias] = extra_network + + def register_default_extra_networks(): from modules.extra_networks_hypernet import ExtraNetworkHypernet register_extra_network(ExtraNetworkHypernet()) @@ -82,20 +88,26 @@ def activate(p, extra_network_data): """call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list""" + activated = [] + for extra_network_name, extra_network_args in extra_network_data.items(): extra_network = extra_network_registry.get(extra_network_name, None) + + if extra_network is None: + extra_network = extra_network_aliases.get(extra_network_name, None) + if extra_network is None: print(f"Skipping unknown extra network: {extra_network_name}") continue try: extra_network.activate(p, extra_network_args) + activated.append(extra_network) except Exception as e: errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") for extra_network_name, extra_network in extra_network_registry.items(): - args = extra_network_data.get(extra_network_name, None) - if args is not None: + if extra_network in activated: continue try: From 05d23c78376ce73d3de932c7e7b8871914295675 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 11:44:29 +0300 Subject: [PATCH 173/178] move generate button below the picture for mobile clients --- .../mobile/javascript/mobile.js | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 extensions-builtin/mobile/javascript/mobile.js diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js new file mode 100644 index 00000000..12cae4b7 --- /dev/null +++ b/extensions-builtin/mobile/javascript/mobile.js @@ -0,0 +1,26 @@ +var isSetupForMobile = false; + +function isMobile() { + for (var tab of ["txt2img", "img2img"]) { + var imageTab = gradioApp().getElementById(tab + '_results'); + if (imageTab && imageTab.offsetParent && imageTab.offsetLeft == 0) { + return true; + } + } + + return false; +} + +function reportWindowSize() { + var currentlyMobile = isMobile(); + if (currentlyMobile == isSetupForMobile) return; + isSetupForMobile = currentlyMobile; + + for (var tab of ["txt2img", "img2img"]) { + var button = gradioApp().getElementById(tab + '_generate_box'); + var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column'); + target.insertBefore(button, target.firstElementChild); + } +} + +window.addEventListener("resize", reportWindowSize); From 699108bfbb05c2a7d2ee4a2c7abcfaa0a244d8ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 18:56:14 +0300 Subject: [PATCH 174/178] hide cards for networks of incompatible stable diffusion version in Lora extra networks interface --- extensions-builtin/Lora/network.py | 20 +++++++++++ .../Lora/scripts/lora_script.py | 2 ++ .../Lora/ui_edit_user_metadata.py | 20 ++++++++--- .../Lora/ui_extra_networks_lora.py | 34 ++++++++++++++++--- html/extra-networks-card.html | 2 +- javascript/extraNetworks.js | 2 +- modules/sd_models.py | 3 ++ modules/ui_extra_networks.py | 3 +- modules/ui_extra_networks_user_metadata.py | 7 +++- style.css | 6 +++- 10 files changed, 84 insertions(+), 15 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index fe42dbdd..8ecfa29a 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -1,5 +1,6 @@ import os from collections import namedtuple +import enum from modules import sd_models, cache, errors, hashes, shared @@ -8,6 +9,13 @@ NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} +class SdVersion(enum.Enum): + Unknown = 1 + SD1 = 2 + SD2 = 3 + SDXL = 4 + + class NetworkOnDisk: def __init__(self, name, filename): self.name = name @@ -44,6 +52,18 @@ class NetworkOnDisk: '' ) + self.sd_version = self.detect_version() + + def detect_version(self): + if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"): + return SdVersion.SDXL + elif str(self.metadata.get('ss_v2', "")) == "True": + return SdVersion.SD2 + elif len(self.metadata): + return SdVersion.SD1 + + return SdVersion.Unknown + def set_hash(self, v): self.hash = v self.shorthash = self.hash[0:12] diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index f478f718..cd28afc9 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -63,6 +63,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks), "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), + "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), + "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), })) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 354a1d68..c8730443 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -46,14 +46,17 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) def __init__(self, ui, tabname, page): super().__init__(ui, tabname, page) + self.select_sd_version = None + self.taginfo = None self.edit_activation_text = None self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, activation_text, preferred_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc + user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight user_metadata["notes"] = notes @@ -112,11 +115,11 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]] return [ - *values[0:4], + *values[0:5], + item.get("sd_version", "Unknown"), gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), - user_metadata.get('notes', ''), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -141,10 +144,15 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) return ", ".join(sorted(res)) + def create_extra_default_items_in_left_column(self): + + # this would be a lot better as gr.Radio but I can't make it work + self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True) + def create_editor(self): self.create_default_editor_elems() - self.taginfo = gr.HighlightedText(label="Tags") + self.taginfo = gr.HighlightedText(label="Training dataset tags") self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) @@ -178,10 +186,11 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_description, self.html_filedata, self.html_preview, + self.edit_notes, + self.select_sd_version, self.taginfo, self.edit_activation_text, self.slider_preferred_weight, - self.edit_notes, row_random_prompt, random_prompt, ] @@ -192,6 +201,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) edited_components = [ self.edit_description, + self.select_sd_version, self.edit_activation_text, self.slider_preferred_weight, self.edit_notes, diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index b6171a26..4b32098b 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,7 +1,9 @@ import os + +import network import networks -from modules import shared, ui_extra_networks +from modules import shared, ui_extra_networks, paths from modules.ui_extra_networks import quote_js from ui_edit_user_metadata import LoraUserMetadataEditor @@ -13,14 +15,13 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): def refresh(self): networks.list_available_networks() - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): lora_on_disk = networks.available_networks.get(name) path, ext = os.path.splitext(lora_on_disk.filename) alias = lora_on_disk.get_alias() - # in 1.5 filename changes to be full filename instead of path without extension, and metadata is dict instead of json string item = { "name": name, "filename": lora_on_disk.filename, @@ -30,6 +31,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, + "sd_version": lora_on_disk.sd_version.name, } self.read_user_metadata(item) @@ -40,15 +42,37 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) + sd_version = item["user_metadata"].get("sd version") + if sd_version in network.SdVersion.__members__: + item["sd_version"] = sd_version + sd_version = network.SdVersion[sd_version] + else: + sd_version = lora_on_disk.sd_version + + if shared.opts.lora_show_all or not enable_filter: + pass + elif sd_version == network.SdVersion.Unknown: + model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1 + if model_version.name in shared.opts.lora_hide_unknown_for_versions: + return None + elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL: + return None + elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2: + return None + elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1: + return None + return item def list_items(self): for index, name in enumerate(networks.available_networks): item = self.create_item(name, index) - yield item + + if item is not None: + yield item def allowed_directories_for_previews(self): - return [shared.cmd_opts.lora_dir] + return [shared.cmd_opts.lora_dir, os.path.join(paths.models_path, "LyCORIS")] def create_user_metadata_editor(self, ui, tabname): return LoraUserMetadataEditor(ui, tabname, self) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index eb8b1a67..39674666 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,8 +1,8 @@
{background_image}
- {edit_button} {metadata_button} + {edit_button}
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index e453094a..5582a6e5 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -213,7 +213,7 @@ function popup(contents) { globalPopupInner.classList.add('global-popup-inner'); globalPopup.appendChild(globalPopupInner); - gradioApp().appendChild(globalPopup); + gradioApp().querySelector('.main').appendChild(globalPopup); } globalPopupInner.innerHTML = ''; diff --git a/modules/sd_models.py b/modules/sd_models.py index 729f03d7..4d9382dd 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -290,6 +290,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer state_dict = get_checkpoint_state_dict(checkpoint_info, timer) model.is_sdxl = hasattr(model, 'conditioner') + model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') + model.is_sd1 = not model.is_sdxl and not model.is_sd2 + if model.is_sdxl: sd_models_xl.extend_sdxl(model) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 6c73998f..49612298 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -62,7 +62,8 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""): page = next(iter([x for x in extra_pages if x.name == page]), None) try: - item = page.create_item(name) + item = page.create_item(name, enable_filter=False) + page.items[name] = item except Exception as e: errors.display(e, "creating item for extra network") item = page.items.get(name) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 01ff4e4b..63d4b503 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -42,6 +42,9 @@ class UserMetadataEditor: return user_metadata + def create_extra_default_items_in_left_column(self): + pass + def create_default_editor_elems(self): with gr.Row(): with gr.Column(scale=2): @@ -49,6 +52,8 @@ class UserMetadataEditor: self.edit_description = gr.Textbox(label="Description", lines=4) self.html_filedata = gr.HTML() + self.create_extra_default_items_in_left_column() + with gr.Column(scale=1, min_width=0): self.html_preview = gr.HTML() @@ -111,7 +116,7 @@ class UserMetadataEditor: table = '' + "".join(f"" for name, value in params) + '' - return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', ''), + return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '') def write_user_metadata(self, name, metadata): item = self.page.items.get(name, {}) diff --git a/style.css b/style.css index 8a66c3d2..e249cfd3 100644 --- a/style.css +++ b/style.css @@ -841,7 +841,7 @@ footer { .extra-network-cards .card .card-button { text-shadow: 2px 2px 3px black; - padding: 0.25em; + padding: 0.25em 0.1em; font-size: 200%; width: 1.5em; } @@ -957,6 +957,10 @@ div.block.gradio-box.edit-user-metadata { text-align: left; } +.edit-user-metadata .file-metadata th, .edit-user-metadata .file-metadata td{ + padding: 0.3em 1em; +} + .edit-user-metadata .wrap.translucent{ background: var(--body-background-fill); } From a99d5708e6d603e8f7cfd1b8c6595f8026219ba0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 20:10:24 +0300 Subject: [PATCH 175/178] skip installing packages with pip if theyare already installed record time it took to launch --- modules/launch_utils.py | 46 ++++++++++++++++++++++++++++++++++++++- requirements_versions.txt | 4 ++-- webui.py | 9 ++++---- 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 434facbc..03552bc2 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -1,4 +1,5 @@ # this scripts installs necessary requirements and launches main program in webui.py +import re import subprocess import os import sys @@ -9,6 +10,9 @@ from functools import lru_cache from modules import cmd_args, errors from modules.paths_internal import script_path, extensions_dir +from modules import timer + +timer.startup_timer.record("start") args, _ = cmd_args.parser.parse_known_args() @@ -226,6 +230,44 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension)) +re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*") + + +def requrements_met(requirements_file): + """ + Does a simple parse of a requirements.txt file to determine if all rerqirements in it + are already installed. Returns True if so, False if not installed or parsing fails. + """ + + import importlib.metadata + import packaging.version + + with open(requirements_file, "r", encoding="utf8") as file: + for line in file: + if line.strip() == "": + continue + + m = re.match(re_requirement, line) + if m is None: + return False + + package = m.group(1).strip() + version_required = (m.group(2) or "").strip() + + if version_required == "": + continue + + try: + version_installed = importlib.metadata.version(package) + except Exception: + return False + + if packaging.version.parse(version_required) != packaging.version.parse(version_installed): + return False + + return True + + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") @@ -311,7 +353,9 @@ def prepare_environment(): if not os.path.isfile(requirements_file): requirements_file = os.path.join(script_path, requirements_file) - run_pip(f"install -r \"{requirements_file}\"", "requirements") + + if not requrements_met(requirements_file): + run_pip(f"install -r \"{requirements_file}\"", "requirements") run_extensions_installers(settings_file=args.ui_settings_file) diff --git a/requirements_versions.txt b/requirements_versions.txt index b826bf43..d07ab456 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -8,7 +8,7 @@ einops==0.4.1 fastapi==0.94.0 gfpgan==1.3.8 gradio==3.32.0 -httpcore<=0.15 +httpcore==0.15 inflection==0.5.1 jsonmerge==1.8.0 kornia==0.6.7 @@ -17,7 +17,7 @@ numpy==1.23.5 omegaconf==2.2.3 open-clip-torch==2.20.0 piexif==1.1.3 -psutil~=5.9.5 +psutil==5.9.5 pytorch_lightning==1.9.4 realesrgan==0.3.0 resize-right==0.0.2 diff --git a/webui.py b/webui.py index 34c2fd18..2aafc09f 100644 --- a/webui.py +++ b/webui.py @@ -31,21 +31,22 @@ if log_level: logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) -from modules import paths, timer, import_hook, errors, devices # noqa: F401 - +from modules import timer startup_timer = timer.startup_timer +startup_timer.record("launcher") import torch import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") - - startup_timer.record("import torch") import gradio # noqa: F401 startup_timer.record("import gradio") +from modules import paths, timer, import_hook, errors, devices # noqa: F401 +startup_timer.record("setup paths") + import ldm.modules.encoders.modules # noqa: F401 startup_timer.record("import ldm") From 17e14ed2d9451859325d275ccc6cdf51fc85a56d Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 18 Jul 2023 10:23:41 +0800 Subject: [PATCH 176/178] Fix wrong key name in lokr module --- extensions-builtin/Lora/network_lokr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 920062e2..3a94f3e9 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -6,8 +6,8 @@ import network class ModuleTypeLokr(network.ModuleType): def create_module(self, net: network.Network, weights: network.NetworkWeights): - has_1 = "lokr_w1" in weights.w or ("lokr_w1a" in weights.w and "lokr_w1b" in weights.w) - has_2 = "lokr_w2" in weights.w or ("lokr_w2a" in weights.w and "lokr_w2b" in weights.w) + has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w) + has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w) if has_1 and has_2: return NetworkModuleLokr(net, weights) From 3d31caf4a53c4bb4469b72790b459eba7b251da9 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 18 Jul 2023 10:45:42 +0800 Subject: [PATCH 177/178] use "is not None" for Tensor --- extensions-builtin/Lora/network_lokr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 3a94f3e9..340acdab 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -28,11 +28,11 @@ class NetworkModuleLokr(network.NetworkModule): self.w1 = weights.w.get("lokr_w1") self.w1a = weights.w.get("lokr_w1_a") self.w1b = weights.w.get("lokr_w1_b") - self.dim = self.w1b.shape[0] if self.w1b else self.dim + self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim self.w2 = weights.w.get("lokr_w2") self.w2a = weights.w.get("lokr_w2_a") self.w2b = weights.w.get("lokr_w2_b") - self.dim = self.w2b.shape[0] if self.w2b else self.dim + self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim self.t2 = weights.w.get("lokr_t2") def calc_updown(self, orig_weight): From f0e2098f1a533c88396536282c1d6cd7d847a51c Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 17 Jul 2023 23:39:38 -0400 Subject: [PATCH 178/178] Add support for `--upcast-sampling` with SD XL --- modules/sd_hijack_unet.py | 8 +++++++- modules/sd_models.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index ca1daf45..2101f1a0 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -39,7 +39,10 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): if isinstance(cond, dict): for y in cond.keys(): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + if isinstance(cond[y], list): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + else: + cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] with devices.autocast(): return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() @@ -77,3 +80,6 @@ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devi CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) + +CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast) +CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4d9382dd..5813b550 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -326,7 +326,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer timer.record("apply half()") - devices.dtype_unet = model.model.diffusion_model.dtype + devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae)