diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 6d860224..137e58f7 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -57,6 +57,7 @@ class LoraUpDownModule: def __init__(self): self.up = None self.down = None + self.alpha = None def assign_lora_names_to_compvis_modules(sd_model): @@ -92,6 +93,15 @@ def load_lora(name, filename): keys_failed_to_match.append(key_diffusers) continue + lora_module = lora.modules.get(key, None) + if lora_module is None: + lora_module = LoraUpDownModule() + lora.modules[key] = lora_module + + if lora_key == "alpha": + lora_module.alpha = weight.item() + continue + if type(sd_module) == torch.nn.Linear: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.Conv2d: @@ -104,17 +114,12 @@ def load_lora(name, filename): module.to(device=devices.device, dtype=devices.dtype) - lora_module = lora.modules.get(key, None) - if lora_module is None: - lora_module = LoraUpDownModule() - lora.modules[key] = lora_module - if lora_key == "lora_up.weight": lora_module.up = module elif lora_key == "lora_down.weight": lora_module.down = module else: - assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight' + assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha' if len(keys_failed_to_match) > 0: print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") @@ -161,7 +166,7 @@ def lora_forward(module, input, res): for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) if module is not None: - res = res + module.up(module.down(input)) * lora.multiplier + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) return res @@ -177,12 +182,12 @@ def lora_Conv2d_forward(self, input): def list_available_loras(): available_loras.clear() - os.makedirs(lora_dir, exist_ok=True) + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) candidates = \ - glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + \ - glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + \ - glob.glob(os.path.join(lora_dir, '**/*.ckpt'), recursive=True) + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) for filename in sorted(candidates): if os.path.isdir(filename): @@ -193,7 +198,6 @@ def list_available_loras(): available_loras[name] = LoraOnDisk(name, filename) -lora_dir = os.path.join(shared.models_path, "Lora") available_loras = {} loaded_loras = [] diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py new file mode 100644 index 00000000..863dc5c0 --- /dev/null +++ b/extensions-builtin/Lora/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora')) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 65397890..54a80d36 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,3 +1,4 @@ +import json import os import lora @@ -26,10 +27,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, - "prompt": f"", + "prompt": json.dumps(f""), "local_preview": path + ".png", } def allowed_directories_for_previews(self): - return [lora.lora_dir] + return [shared.cmd_opts.lora_dir] diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 9a74b253..e8783bca 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared import cmd_opts, opts +from modules.shared import cmd_opts, opts, state from swinir_model_arch import SwinIR as net from swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData @@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale): with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: + if state.interrupted or state.skipped: + break + for w_idx in w_idx_list: + if state.interrupted or state.skipped: + break + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] out_patch = model(in_patch) out_patch_mask = torch.ones_like(out_patch) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 1bdf1d27..aa9fca87 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,8 +1,8 @@ -
+
{name} diff --git a/html/image-update.svg b/html/image-update.svg new file mode 100644 index 00000000..3abf12df --- /dev/null +++ b/html/image-update.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index cec6a530..619bb1fa 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -1,74 +1,96 @@ -addEventListener('keydown', (event) => { +function keyupEditAttention(event){ let target = event.originalTarget || event.composedPath()[0]; if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return; if (! (event.metaKey || event.ctrlKey)) return; - - let plus = "ArrowUp" - let minus = "ArrowDown" - if (event.key != plus && event.key != minus) return; + let isPlus = event.key == "ArrowUp" + let isMinus = event.key == "ArrowDown" + if (!isPlus && !isMinus) return; let selectionStart = target.selectionStart; let selectionEnd = target.selectionEnd; - // If the user hasn't selected anything, let's select their current parenthesis block - if (selectionStart === selectionEnd) { + let text = target.value; + + function selectCurrentParenthesisBlock(OPEN, CLOSE){ + if (selectionStart !== selectionEnd) return false; + // Find opening parenthesis around current cursor - const before = target.value.substring(0, selectionStart); - let beforeParen = before.lastIndexOf("("); - if (beforeParen == -1) return; - let beforeParenClose = before.lastIndexOf(")"); + const before = text.substring(0, selectionStart); + let beforeParen = before.lastIndexOf(OPEN); + if (beforeParen == -1) return false; + let beforeParenClose = before.lastIndexOf(CLOSE); while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { - beforeParen = before.lastIndexOf("(", beforeParen - 1); - beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1); + beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); + beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); } // Find closing parenthesis around current cursor - const after = target.value.substring(selectionStart); - let afterParen = after.indexOf(")"); - if (afterParen == -1) return; - let afterParenOpen = after.indexOf("("); + const after = text.substring(selectionStart); + let afterParen = after.indexOf(CLOSE); + if (afterParen == -1) return false; + let afterParenOpen = after.indexOf(OPEN); while (afterParenOpen !== -1 && afterParen > afterParenOpen) { - afterParen = after.indexOf(")", afterParen + 1); - afterParenOpen = after.indexOf("(", afterParenOpen + 1); + afterParen = after.indexOf(CLOSE, afterParen + 1); + afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); } - if (beforeParen === -1 || afterParen === -1) return; + if (beforeParen === -1 || afterParen === -1) return false; // Set the selection to the text between the parenthesis - const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen); + const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); const lastColon = parenContent.lastIndexOf(":"); selectionStart = beforeParen + 1; selectionEnd = selectionStart + lastColon; target.setSelectionRange(selectionStart, selectionEnd); - } + return true; + } + + // If the user hasn't selected anything, let's select their current parenthesis block + if(! selectCurrentParenthesisBlock('<', '>')){ + selectCurrentParenthesisBlock('(', ')') + } event.preventDefault(); - if (selectionStart == 0 || target.value[selectionStart - 1] != "(") { - target.value = target.value.slice(0, selectionStart) + - "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" + - target.value.slice(selectionEnd); + closeCharacter = ')' + delta = opts.keyedit_precision_attention - target.focus(); - target.selectionStart = selectionStart + 1; - target.selectionEnd = selectionEnd + 1; + if (selectionStart > 0 && text[selectionStart - 1] == '<'){ + closeCharacter = '>' + delta = opts.keyedit_precision_extra + } else if (selectionStart == 0 || text[selectionStart - 1] != "(") { - } else { - end = target.value.slice(selectionEnd + 1).indexOf(")") + 1; - weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end)); - if (isNaN(weight)) return; - if (event.key == minus) weight -= 0.1; - if (event.key == plus) weight += 0.1; + // do not include spaces at the end + while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){ + selectionEnd -= 1; + } + if(selectionStart == selectionEnd){ + return + } - weight = parseFloat(weight.toPrecision(12)); + text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd); - target.value = target.value.slice(0, selectionEnd + 1) + - weight + - target.value.slice(selectionEnd + 1 + end - 1); + selectionStart += 1; + selectionEnd += 1; + } - target.focus(); - target.selectionStart = selectionStart; - target.selectionEnd = selectionEnd; - } + end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); + if (isNaN(weight)) return; + + weight += isPlus ? delta : -delta; + weight = parseFloat(weight.toPrecision(12)); + if(String(weight).length == 1) weight += ".0" + + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); + + target.focus(); + target.value = text; + target.selectionStart = selectionStart; + target.selectionEnd = selectionEnd; updateInput(target) -}); +} + +addEventListener('keydown', (event) => { + keyupEditAttention(event); +}); \ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 54ded58c..c5a9adb3 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -13,10 +13,10 @@ function setupExtraNetworksForTab(tabname){ tabs.appendChild(close) search.addEventListener("input", function(evt){ - searchTerm = search.value + searchTerm = search.value.toLowerCase() gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){ - text = elem.querySelector('.name').textContent + text = elem.querySelector('.name').textContent.toLowerCase() elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : "" }) }); diff --git a/javascript/hints.js b/javascript/hints.js index ef410fba..3cf10e20 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -107,7 +107,10 @@ titles = { "Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.", "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.", "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", - "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders." + "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.", + "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.", + "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.", + "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited." } diff --git a/javascript/ui.js b/javascript/ui.js index 77256e15..ba72623c 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -104,11 +104,6 @@ function create_tab_index_args(tabId, args){ return res } -function get_extras_tab_index(){ - const [,,...args] = [...arguments] - return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args] -} - function get_img2img_tab_index() { let res = args_to_array(arguments) res.splice(-2) diff --git a/launch.py b/launch.py index 5afb2956..82094fa0 100644 --- a/launch.py +++ b/launch.py @@ -179,7 +179,7 @@ def run_extensions_installers(settings_file): def prepare_environment(): global skip_install - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -187,8 +187,6 @@ def prepare_environment(): clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") - xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') - stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') @@ -210,6 +208,7 @@ def prepare_environment(): sys.argv, _ = extract_arg(sys.argv, '-f') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') + sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch') sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests') sys.argv, skip_install = extract_arg(sys.argv, '--skip-install') @@ -221,7 +220,7 @@ def prepare_environment(): print(f"Python {sys.version}") print(f"Commit hash: {commit}") - if not is_installed("torch") or not is_installed("torchvision"): + if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") if not skip_torch_cuda_test: @@ -239,14 +238,14 @@ def prepare_environment(): if (not is_installed("xformers") or reinstall_xformers) and xformers: if platform.system() == "Windows": if platform.python_version().startswith("3.10"): - run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers") + run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers") else: print("Installation of xformers is not supported in this version of Python.") print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") if not is_installed("xformers"): exit(0) elif platform.system() == "Linux": - run_pip("install xformers", "xformers") + run_pip("install xformers==0.0.16rc425", "xformers") if not is_installed("pyngrok") and ngrok: run_pip("install pyngrok", "ngrok") diff --git a/modules/api/api.py b/modules/api/api.py index f2e9e884..25c65e57 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.extras import run_extras from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork @@ -23,6 +22,8 @@ from modules.sd_models import checkpoints_list, find_checkpoint_config from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List +import piexif +import piexif.helper def upscaler_to_index(name: str): try: @@ -45,32 +46,46 @@ def validate_sampler_name(name): def setUpscalers(req: dict): reqDict = vars(req) - reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) - reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') + reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) + reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None) return reqDict def decode_base64_to_image(encoding): if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] - return Image.open(BytesIO(base64.b64decode(encoding))) + try: + image = Image.open(BytesIO(base64.b64decode(encoding))) + return image + except Exception as err: + raise HTTPException(status_code=500, detail="Invalid encoded image") def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - # Copy any text-only metadata - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True + if opts.samples_format.lower() == 'png': + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) + + elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + parameters = image.info.get('parameters', None) + exif_bytes = piexif.dump({ + "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } + }) + if opts.samples_format.lower() in ("jpg", "jpeg"): + image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality) + else: + image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality) + + else: + raise HTTPException(status_code=500, detail="Invalid image format") - image.save( - output_bytes, "PNG", pnginfo=(metadata if use_metadata else None) - ) bytes_data = output_bytes.getvalue() + return base64.b64encode(bytes_data) def api_middleware(app: FastAPI): @@ -244,7 +259,7 @@ class Api: reqDict['image'] = decode_base64_to_image(reqDict['image']) with self.queue_lock: - result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) @@ -260,7 +275,7 @@ class Api: reqDict.pop('imageList') with self.queue_lock: - result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) @@ -360,13 +375,16 @@ class Api: return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): - upscalers = [] - - for upscaler in shared.sd_upscalers: - u = upscaler.scaler - upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url}) - - return upscalers + return [ + { + "name": upscaler.name, + "model_name": upscaler.scaler.model_name, + "model_path": upscaler.data_path, + "model_url": None, + "scale": upscaler.scale, + } + for upscaler in shared.sd_upscalers + ] def get_sd_models(self): return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] diff --git a/modules/api/models.py b/modules/api/models.py index 1eb1fcf1..805bd8f7 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -220,6 +220,7 @@ class UpscalerItem(BaseModel): model_name: Optional[str] = Field(title="Model Name") model_path: Optional[str] = Field(title="Path") model_url: Optional[str] = Field(title="URL") + scale: Optional[float] = Field(title="Scale") class SDModelItem(BaseModel): title: str = Field(title="Title") diff --git a/modules/errors.py b/modules/errors.py index a10e8708..f6b80dbb 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -24,6 +24,18 @@ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable """) +already_displayed = {} + + +def display_once(e: Exception, task): + if task in already_displayed: + return + + display(e, task) + + already_displayed[task] = 1 + + def run(code, task): try: code() diff --git a/modules/extras.py b/modules/extras.py index 1218f88f..36123aa5 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,230 +1,16 @@ -from __future__ import annotations -import math import os -import sys -import traceback +import re import shutil -import numpy as np -from PIL import Image import torch import tqdm -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model -from modules.ui import plaintext_to_html -import modules.codeformer_model +from modules import shared, images, sd_models, sd_vae +from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - def run_pnginfo(image): if image is None: @@ -285,7 +71,7 @@ def to_half(tensor, enable): return tensor -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae): +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): shared.state.begin() shared.state.job = 'model-merge' @@ -430,6 +216,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ for key in theta_0.keys(): theta_0[key] = to_half(theta_0[key], save_as_half) + if discard_weights: + regex = re.compile(discard_weights) + for key in list(theta_0): + if re.search(regex, key): + theta_0.pop(key, None) + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path filename = filename_generator() if custom_name == '' else custom_name diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 80a47c79..503534e2 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -715,6 +715,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi do_not_save_samples=True, ) + p.disable_extra_networks = True + if preview_from_txt2img: p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt diff --git a/modules/interrogate.py b/modules/interrogate.py index 19938cbb..236e6983 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -2,6 +2,7 @@ import os import sys import traceback from collections import namedtuple +from pathlib import Path import re import torch @@ -20,19 +21,20 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") +def category_types(): + return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')] + def download_default_clip_interrogate_categories(content_dir): print("Downloading CLIP categories...") tmpdir = content_dir + "_tmp" + category_types = ["artists", "flavors", "mediums", "movements"] + try: os.makedirs(tmpdir) - - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt")) - + for category_type in category_types: + torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt")) os.rename(tmpdir, content_dir) except Exception as e: @@ -51,32 +53,38 @@ class InterrogateModels: def __init__(self, content_dir): self.loaded_categories = None + self.skip_categories = [] self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") def categories(self): - if self.loaded_categories is not None: - return self.loaded_categories - - self.loaded_categories = [] - if not os.path.exists(self.content_dir): download_default_clip_interrogate_categories(self.content_dir) - if os.path.exists(self.content_dir): - for filename in os.listdir(self.content_dir): - m = re_topn.search(filename) - topn = 1 if m is None else int(m.group(1)) + if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories: + return self.loaded_categories - with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file: + self.loaded_categories = [] + + if os.path.exists(self.content_dir): + self.skip_categories = shared.opts.interrogate_clip_skip_categories + category_types = [] + for filename in Path(self.content_dir).glob('*.txt'): + category_types.append(filename.stem) + if filename.stem in self.skip_categories: + continue + m = re_topn.search(filename.stem) + topn = 1 if m is None else int(m.group(1)) + with open(filename, "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.loaded_categories.append(Category(name=filename, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines)) return self.loaded_categories def load_blip_model(self): - import models.blip + with paths.Prioritize("BLIP"): + import models.blip files = modelloader.load_models( model_path=os.path.join(paths.models_path, "BLIP"), @@ -139,6 +147,8 @@ class InterrogateModels: def rank(self, image_features, text_array, top_count=1): import clip + devices.torch_gc() + if shared.opts.interrogate_clip_dict_limit != 0: text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] diff --git a/modules/paths.py b/modules/paths.py index 4dd03a35..20b3e4d8 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,3 +38,17 @@ for d, must_exist, what, options in path_dirs: else: sys.path.append(d) paths[what] = d + + +class Prioritize: + def __init__(self, name): + self.name = name + self.path = None + + def __enter__(self): + self.path = sys.path.copy() + sys.path = [paths[self.name]] + sys.path + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.path = self.path + self.path = None diff --git a/modules/postprocessing.py b/modules/postprocessing.py new file mode 100644 index 00000000..09d8e605 --- /dev/null +++ b/modules/postprocessing.py @@ -0,0 +1,103 @@ +import os + +from PIL import Image + +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste +from modules.shared import opts + + +def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): + devices.torch_gc() + + shared.state.begin() + shared.state.job = 'extras' + + image_data = [] + image_names = [] + outputs = [] + + if extras_mode == 1: + for img in image_folder: + image = Image.open(img) + image_data.append(image) + image_names.append(os.path.splitext(img.orig_name)[0]) + elif extras_mode == 2: + assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + assert input_dir, 'input directory not selected' + + image_list = shared.listfiles(input_dir) + for filename in image_list: + try: + image = Image.open(filename) + except Exception: + continue + image_data.append(image) + image_names.append(filename) + else: + assert image, 'image not selected' + + image_data.append(image) + image_names.append(None) + + if extras_mode == 2 and output_dir != '': + outpath = output_dir + else: + outpath = opts.outdir_samples or opts.outdir_extras_samples + + infotext = '' + + for image, name in zip(image_data, image_names): + shared.state.textinfo = name + + existing_pnginfo = image.info or {} + + pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) + + scripts.scripts_postproc.run(pp, args) + + if opts.use_original_name_batch and name is not None: + basename = os.path.splitext(os.path.basename(name))[0] + else: + basename = '' + + infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) + + if opts.enable_pnginfo: + pp.image.info = existing_pnginfo + pp.image.info["postprocessing"] = infotext + + if save_output: + images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) + + if extras_mode != 2 or show_extras_results: + outputs.append(pp.image) + + devices.torch_gc() + + return outputs, ui_common.plaintext_to_html(infotext), '' + + +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): + """old handler for API""" + + args = scripts.scripts_postproc.create_args_for_run({ + "Upscale": { + "upscale_mode": resize_mode, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + }, + "GFPGAN": { + "gfpgan_visibility": gfpgan_visibility, + }, + "CodeFormer": { + "codeformer_visibility": codeformer_visibility, + "codeformer_weight": codeformer_weight, + }, + }) + + return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) diff --git a/modules/processing.py b/modules/processing.py index 01c1b53c..a1138491 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -140,6 +140,7 @@ class StableDiffusionProcessing: self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.override_settings_restore_afterwards = override_settings_restore_afterwards self.is_using_inpainting_conditioning = False + self.disable_extra_networks = False if not seed_enable_extras: self.subseed = -1 @@ -579,7 +580,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) - extra_networks.activate(p, extra_network_data) + if not p.disable_extra_networks: + extra_networks.activate(p, extra_network_data) with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") @@ -723,7 +725,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) - extra_networks.deactivate(p, extra_network_data) + if not p.disable_extra_networks: + extra_networks.deactivate(p, extra_network_data) + devices.torch_gc() res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) diff --git a/modules/scripts.py b/modules/scripts.py index 4ffc369b..03907a63 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,7 @@ from collections import namedtuple import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks, extensions, script_loading +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing AlwaysVisible = object() @@ -150,8 +150,10 @@ def basedir(): return current_basedir -scripts_data = [] ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) + +scripts_data = [] +postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) @@ -190,23 +192,31 @@ def list_files_with_name(filename): def load_scripts(): global current_basedir scripts_data.clear() + postprocessing_scripts_data.clear() script_callbacks.clear_callbacks() scripts_list = list_scripts("scripts", ".py") syspath = sys.path + def register_scripts_from_module(module): + for key, script_class in module.__dict__.items(): + if type(script_class) != type: + continue + + if issubclass(script_class, Script): + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): + postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + for scriptfile in sorted(scripts_list): try: if scriptfile.basedir != paths.script_path: sys.path = [scriptfile.basedir] + sys.path current_basedir = scriptfile.basedir - module = script_loading.load_module(scriptfile.path) - - for key, script_class in module.__dict__.items(): - if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + script_module = script_loading.load_module(scriptfile.path) + register_scripts_from_module(script_module) except Exception: print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) @@ -413,6 +423,7 @@ class ScriptRunner: scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() +scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner() scripts_current: ScriptRunner = None @@ -423,12 +434,13 @@ def reload_script_body_only(): def reload_scripts(): - global scripts_txt2img, scripts_img2img + global scripts_txt2img, scripts_img2img, scripts_postproc load_scripts() scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner() def IOComponent_init(self, *args, **kwargs): diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py new file mode 100644 index 00000000..25de02d0 --- /dev/null +++ b/modules/scripts_postprocessing.py @@ -0,0 +1,147 @@ +import os +import gradio as gr + +from modules import errors, shared + + +class PostprocessedImage: + def __init__(self, image): + self.image = image + self.info = {} + + +class ScriptPostprocessing: + filename = None + controls = None + args_from = None + args_to = None + + order = 1000 + """scripts will be ordred by this value in postprocessing UI""" + + name = None + """this function should return the title of the script.""" + + group = None + """A gr.Group component that has all script's UI inside it""" + + def ui(self): + """ + This function should create gradio UI elements. See https://gradio.app/docs/#components + The return value should be a dictionary that maps parameter names to components used in processing. + Values of those components will be passed to process() function. + """ + + pass + + def process(self, pp: PostprocessedImage, **args): + """ + This function is called to postprocess the image. + args contains a dictionary with all values returned by components from ui() + """ + + pass + + def image_changed(self): + pass + + +def wrap_call(func, filename, funcname, *args, default=None, **kwargs): + try: + res = func(*args, **kwargs) + return res + except Exception as e: + errors.display(e, f"calling {filename}/{funcname}") + + return default + + +class ScriptPostprocessingRunner: + def __init__(self): + self.scripts = None + self.ui_created = False + + def initialize_scripts(self, scripts_data): + self.scripts = [] + + for script_class, path, basedir, script_module in scripts_data: + script: ScriptPostprocessing = script_class() + script.filename = path + + self.scripts.append(script) + + def create_script_ui(self, script, inputs): + script.args_from = len(inputs) + script.args_to = len(inputs) + + script.controls = wrap_call(script.ui, script.filename, "ui") + + for control in script.controls.values(): + control.custom_script_source = os.path.basename(script.filename) + + inputs += list(script.controls.values()) + script.args_to = len(inputs) + + def scripts_in_preferred_order(self): + if self.scripts is None: + import modules.scripts + self.initialize_scripts(modules.scripts.postprocessing_scripts_data) + + scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")] + + def script_score(name): + name = name.lower() + for i, possible_match in enumerate(scripts_order): + if possible_match in name: + return i + + return len(self.scripts) + + script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)} + + return sorted(self.scripts, key=lambda x: script_scores[x.name]) + + def setup_ui(self): + inputs = [] + + for script in self.scripts_in_preferred_order(): + with gr.Box() as group: + self.create_script_ui(script, inputs) + + script.group = group + + self.ui_created = True + return inputs + + def run(self, pp: PostprocessedImage, args): + for script in self.scripts_in_preferred_order(): + shared.state.job = script.name + + script_args = args[script.args_from:script.args_to] + + process_args = {} + for (name, component), value in zip(script.controls.items(), script_args): + process_args[name] = value + + script.process(pp, **process_args) + + def create_args_for_run(self, scripts_args): + if not self.ui_created: + with gr.Blocks(analytics_enabled=False): + self.setup_ui() + + scripts = self.scripts_in_preferred_order() + args = [None] * max([x.args_to for x in scripts]) + + for script in scripts: + script_args_dict = scripts_args.get(script.name, None) + if script_args_dict is not None: + + for i, name in enumerate(script.controls): + args[script.args_from + i] = script_args_dict.get(name, None) + + return args + + def image_changed(self): + for script in self.scripts_in_preferred_order(): + script.image_changed() diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index c72d8efc..e90aa9fe 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -41,7 +41,9 @@ class DisableInitialization: return self.create_model_and_transforms(*args, pretrained=None, **kwargs) def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): - return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) + res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) + res.name_or_path = pretrained_model_name_or_path + return res def transformers_modeling_utils_load_pretrained_model(*args, **kwargs): args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 4fa54329..74452709 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared +from modules import shared, errors from modules.hypernetworks import hypernetwork from .sub_quadratic_attention import efficient_dot_product_attention @@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ ) +def get_xformers_flash_attention_op(q, k, v): + if not shared.cmd_opts.xformers_flash_attention: + return None + + try: + flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = flash_attention_op + if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + return flash_attention_op + except Exception as e: + errors.display_once(e, "enabling flash attention") + + return None + + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) @@ -290,7 +305,8 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -365,7 +381,7 @@ def xformers_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v) + out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 52bbb807..5f713bee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -57,6 +57,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") +parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") @@ -398,7 +399,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), + "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), @@ -406,7 +407,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), - 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) options_templates.update(options_section(('compatibility', "Compatibility"), { @@ -422,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), + "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), @@ -429,6 +432,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"), })) +options_templates.update(options_section(('extra_networks', "Extra Networks"), { + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }), +})) + options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), @@ -443,9 +450,12 @@ options_templates.update(options_section(('ui', "User interface"), { "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"), - 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), - 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), - 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), + "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), + "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), + "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), + "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) options_templates.update(options_section(('ui', "Live previews"), { @@ -470,6 +480,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), })) +options_templates.update(options_section(('postprocessing', "Postprocessing"), { + 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"), + 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), +})) + options_templates.update(options_section((None, "Hidden options"), { "disabled_extensions": OptionInfo([], "Disable those extensions"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), diff --git a/modules/ui.py b/modules/ui.py index 64180889..e701393c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,7 +5,6 @@ import mimetypes import os import platform import random -import subprocess as sp import sys import tempfile import time @@ -20,7 +19,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -41,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text +import modules.extras warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) @@ -75,6 +75,7 @@ css_hide_progressbar = """ .wrap .m-12::before { content:"Loading..." } .wrap .z-20 svg { display:none!important; } .wrap .z-20::before { content:"Loading..." } +.wrap.cover-bg .z-20::before { content:"" } .progress-bar { display:none!important; } .meta-text { display:none!important; } .meta-text-center { display:none!important; } @@ -85,7 +86,6 @@ css_hide_progressbar = """ random_symbol = '\U0001f3b2\ufe0f' # 🎲️ reuse_symbol = '\u267b\ufe0f' # ♻️ paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 @@ -94,78 +94,14 @@ extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): - text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" - return text + return ui_common.plaintext_to_html(text) + def send_gradio_gallery_to_image(x): if len(x) == 0: return None return image_from_url_text(x[0]) -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - def visit(x, func, path=""): if hasattr(x, 'children'): for c in x.children: @@ -443,19 +379,6 @@ def apply_setting(key, value): opts.save(shared.config_filename) return getattr(opts, key) - -def update_generation_info(generation_info, html_info, img_index): - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info, gr.update() - return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info, gr.update() - - def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def refresh(): refresh_method() @@ -476,107 +399,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel', elem_id=f"{tabname}_results"): - with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[generation_info, html_info, html_info], - outputs=[html_info, html_info], - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ], - show_progress=False, - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + return ui_common.create_output_panel(tabname, outdir) def create_sampler_and_steps_selection(choices, tabname): @@ -935,7 +758,7 @@ def create_ui(): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): + with FormRow(elem_id="img2img_checkboxes", variant="compact"): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") @@ -1122,86 +945,7 @@ def create_ui(): modules.scripts.scripts_current = None with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='compact'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, - html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) + ui_postprocessing.create_ui() with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Row().style(equal_height=False): @@ -1222,10 +966,19 @@ def create_ui(): outputs=[html, generation_info, html2], ) + def update_interp_description(value): + interp_description_css = "

{}

" + interp_descriptions = { + "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."), + "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"), + "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M") + } + return interp_descriptions[value] + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='compact'): - gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") + interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description") with FormRow(elem_id="modelmerger_models"): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") @@ -1240,6 +993,7 @@ def create_ui(): custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description]) with FormRow(): checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") @@ -1254,6 +1008,9 @@ def create_ui(): bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae") create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae") + with FormRow(): + discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights") + with gr.Row(): modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') @@ -1265,7 +1022,7 @@ def create_ui(): with gr.Row().style(equal_height=False): gr.HTML(value="

See wiki for detailed explanation.

") - with gr.Row().style(equal_height=False): + with gr.Row(variant="compact").style(equal_height=False): with gr.Tabs(elem_id="train_tabs"): with gr.Tab(label="Create embedding"): @@ -1844,6 +1601,7 @@ def create_ui(): checkpoint_format, config_source, bake_in_vae, + discard_weights, ], outputs=[ primary_model_name, @@ -1931,28 +1689,27 @@ def create_ui(): with open(ui_config_file, "w", encoding="utf8") as file: json.dump(ui_settings, file, indent=4) + # Required as a workaround for change() event not triggering when loading values from ui-config.json + interp_description.value = update_interp_description(interp_method.value) + return demo def reload_javascript(): - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' - - scripts_list = modules.scripts.list_scripts("javascript", ".js") - - for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" + head = f'\n' + inline = f"{localization.localization_js(shared.opts.localization)};" if cmd_opts.theme is not None: - javascript += f"\n\n" + inline += f"set_theme('{cmd_opts.theme}');" - javascript += f"\n" + for script in modules.scripts.list_scripts("javascript", ".js"): + head += f'\n' + + head += f'\n' def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) + res.body = res.body.replace(b'', f'{head}'.encode("utf8")) res.init_headers() return res diff --git a/modules/ui_common.py b/modules/ui_common.py new file mode 100644 index 00000000..9405ac1f --- /dev/null +++ b/modules/ui_common.py @@ -0,0 +1,202 @@ +import json +import html +import os +import platform +import sys + +import gradio as gr +import subprocess as sp + +from modules import call_queue, shared +from modules.generation_parameters_copypaste import image_from_url_text +import modules.images + +folder_symbol = '\U0001f4c2' # 📂 + + +def update_generation_info(generation_info, html_info, img_index): + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info, gr.update() + return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info, gr.update() + + +def plaintext_to_html(text): + text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" + return text + + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = shared.opts.outdir_save + save_to_dirs = shared.opts.use_save_to_dirs_for_ui + extension: str = shared.opts.samples_format + start_index = 0 + + if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(shared.opts.outdir_save, exist_ok=True) + + with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def create_output_panel(tabname, outdir): + from modules import shared + import modules.generation_parameters_copypaste as parameters_copypaste + + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel', elem_id=f"{tabname}_results"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(shared.opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", + inputs=[generation_info, html_info, html_info], + outputs=[html_info, html_info], + ) + + save.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ], + show_progress=False, + ) + + save_zip.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log diff --git a/modules/ui_components.py b/modules/ui_components.py index 46324425..9aec3097 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -47,3 +47,4 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" + diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index e2e060c8..8b4f97f8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -3,6 +3,7 @@ import os.path from modules import shared import gradio as gr import json +import html from modules.generation_parameters_copypaste import image_from_url_text @@ -26,6 +27,7 @@ class ExtraNetworksPage: pass def create_html(self, tabname): + view = shared.opts.extra_networks_default_view items_html = '' for item in self.list_items(): @@ -36,7 +38,7 @@ class ExtraNetworksPage: items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) res = f""" -
+
{items_html}
""" @@ -53,12 +55,13 @@ class ExtraNetworksPage: preview = item.get("preview", None) args = { - "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', - "prompt": json.dumps(item["prompt"]), + "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', + "prompt": item["prompt"], "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], - "allow_negative_prompt": "true" if self.allow_negative_prompt else "false", + "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"', + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', } return self.card_page.format(**args) @@ -79,10 +82,26 @@ class ExtraNetworksUi: self.tabname = None +def pages_in_preferred_order(pages): + tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")] + + def tab_name_score(name): + name = name.lower() + for i, possible_match in enumerate(tab_order): + if possible_match in name: + return i + + return len(pages) + + tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)} + + return sorted(pages, key=lambda x: tab_scores[x.name]) + + def create_ui(container, button, tabname): ui = ExtraNetworksUi() ui.pages = [] - ui.stored_extra_pages = extra_pages.copy() + ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.tabname = tabname with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 312dbaf0..65d000cf 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -1,3 +1,4 @@ +import json import os from modules import shared, ui_extra_networks @@ -25,7 +26,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, - "prompt": f"", + "prompt": json.dumps(f""), "local_preview": path + ".png", } diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index e4a6e3bf..dbd23d2d 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,3 +1,4 @@ +import json import os from modules import ui_extra_networks, sd_hijack @@ -24,7 +25,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "name": embedding.name, "filename": embedding.filename, "preview": preview, - "prompt": embedding.name, + "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", } diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py new file mode 100644 index 00000000..b418d955 --- /dev/null +++ b/modules/ui_postprocessing.py @@ -0,0 +1,57 @@ +import gradio as gr +from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue +import modules.generation_parameters_copypaste as parameters_copypaste + + +def create_ui(): + tab_index = gr.State(value=0) + + with gr.Row().style(equal_height=False, variant='compact'): + with gr.Column(variant='compact'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single: + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch: + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir: + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + script_inputs = scripts.scripts_postproc.setup_ui() + + with gr.Column(): + result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) + + tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) + tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) + tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) + + submit.click( + fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + inputs=[ + tab_index, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + *script_inputs + ], + outputs=[ + result_images, + html_info_x, + html_info, + ] + ) + + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=scripts.scripts_postproc.image_changed, + inputs=[], outputs=[] + ) diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py new file mode 100644 index 00000000..a7d80d40 --- /dev/null +++ b/scripts/postprocessing_codeformer.py @@ -0,0 +1,36 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, codeformer_model +import gradio as gr + +from modules.ui_components import FormRow + + +class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): + name = "CodeFormer" + order = 3000 + + def ui(self): + with FormRow(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") + + return { + "codeformer_visibility": codeformer_visibility, + "codeformer_weight": codeformer_weight, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): + if codeformer_visibility == 0: + return + + restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(pp.image, res, codeformer_visibility) + + pp.image = res + pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) + pp.info["CodeFormer weight"] = round(codeformer_weight, 3) diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py new file mode 100644 index 00000000..d854f3f7 --- /dev/null +++ b/scripts/postprocessing_gfpgan.py @@ -0,0 +1,33 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, gfpgan_model +import gradio as gr + +from modules.ui_components import FormRow + + +class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): + name = "GFPGAN" + order = 2000 + + def ui(self): + with FormRow(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") + + return { + "gfpgan_visibility": gfpgan_visibility, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): + if gfpgan_visibility == 0: + return + + restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(pp.image, res, gfpgan_visibility) + + pp.image = res + pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py new file mode 100644 index 00000000..095d29b2 --- /dev/null +++ b/scripts/postprocessing_upscale.py @@ -0,0 +1,106 @@ +from PIL import Image +import numpy as np + +from modules import scripts_postprocessing, shared +import gradio as gr + +from modules.ui_components import FormRow + + +upscale_cache = {} + + +class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): + name = "Upscale" + order = 1000 + + def ui(self): + selected_tab = gr.State(value=0) + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: + with FormRow(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with FormRow(): + extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + + with FormRow(): + extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") + + tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) + tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) + + return { + "upscale_mode": selected_tab, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + } + + def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop): + if upscale_mode == 1: + upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height) + info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}" + else: + info["Postprocess upscale by"] = upscale_by + + cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + cached_image = upscale_cache.pop(cache_key, None) + + if cached_image is not None: + image = cached_image + else: + image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path) + + upscale_cache[cache_key] = image + if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache: + upscale_cache.pop(next(iter(upscale_cache), None), None) + + if upscale_mode == 1 and upscale_crop: + cropped = Image.new("RGB", (upscale_to_width, upscale_to_height)) + cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2)) + image = cropped + info["Postprocess crop to"] = f"{image.width}x{image.height}" + + return image + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): + if upscaler_1_name == "None": + upscaler_1_name = None + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None) + assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}' + + if not upscaler1: + return + + if upscaler_2_name == "None": + upscaler_2_name = None + + upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None) + assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}' + + upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + pp.info[f"Postprocess upscaler"] = upscaler1.name + + if upscaler2 and upscaler_2_visibility > 0: + second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop) + upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility) + + pp.info[f"Postprocess upscaler 2"] = upscaler2.name + + pp.image = upscaled_image + + def image_changed(self): + upscale_cache.clear() diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index b1badec9..1a452355 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -165,10 +165,14 @@ class AxisOption: self.confirm = confirm self.cost = cost self.choices = choices - self.is_img2img = False class AxisOptionImg2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = True + +class AxisOptionTxt2Img(AxisOption): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.is_img2img = False @@ -180,10 +184,12 @@ axis_options = [ AxisOption("Var. seed", int, apply_field("subseed")), AxisOption("Var. strength", float, apply_field("subseed_strength")), AxisOption("Steps", int, apply_field("steps")), + AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), AxisOption("CFG Scale", float, apply_field("cfg_scale")), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), - AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma min", float, apply_field("s_tmin")), @@ -192,8 +198,8 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta")), AxisOption("Clip skip", int, apply_clip_skip), AxisOption("Denoising", float, apply_field("denoising_strength")), - AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]), - AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), + AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), + AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), ] @@ -288,42 +294,41 @@ class Script(scripts.Script): return "X/Y plot" def ui(self, is_img2img): - current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img] + self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] with gr.Row(): with gr.Column(scale=19): with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False) with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) - with gr.Row(variant="compact"): + with gr.Row(variant="compact", elem_id="axis_options"): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") def swap_axes(x_type, x_values, y_type, y_values): - nonlocal current_axis_options - return current_axis_options[y_type].label, y_values, current_axis_options[x_type].label, x_values + return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values swap_args = [x_type, x_values, y_type, y_values] swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) def fill(x_type): - axis = axis_options[x_type] + axis = self.current_axis_options[x_type] return ", ".join(axis.choices()) if axis.choices else gr.update() fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) def select_axis(x_type): - return gr.Button.update(visible=axis_options[x_type].choices is not None) + return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) @@ -398,10 +403,10 @@ class Script(scripts.Script): return valslist - x_opt = axis_options[x_type] + x_opt = self.current_axis_options[x_type] xs = process_axis(x_opt, x_values) - y_opt = axis_options[y_type] + y_opt = self.current_axis_options[y_type] ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): @@ -422,10 +427,21 @@ class Script(scripts.Script): total_steps = p.steps * len(xs) * len(ys) if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: - total_steps *= 2 + if x_opt.label == "Hires steps": + total_steps += sum(xs) * len(ys) + elif y_opt.label == "Hires steps": + total_steps += sum(ys) * len(xs) + elif p.hr_second_pass_steps: + total_steps += p.hr_second_pass_steps * len(xs) * len(ys) + else: + total_steps *= 2 - print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") - shared.total_tqdm.updateTotal(total_steps * p.n_iter) + total_steps *= p.n_iter + + image_cell_count = p.n_iter * p.batch_size + cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" + print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})") + shared.total_tqdm.updateTotal(total_steps) grid_infotext = [None] diff --git a/style.css b/style.css index 1e59575f..b2677fa1 100644 --- a/style.css +++ b/style.css @@ -589,7 +589,7 @@ canvas[key="mask"] { /* Extensions */ -#tab_extensions table``{ +#tab_extensions table{ border-collapse: collapse; } @@ -707,12 +707,24 @@ footer { #txt2img_checkboxes, #img2img_checkboxes{ margin-bottom: 0.5em; + margin-left: 0em; } #txt2img_checkboxes > div, #img2img_checkboxes > div{ flex: 0; white-space: nowrap; min-width: auto; } +#txt2img_hires_fix{ + margin-left: -0.8em; +} + +#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{ + margin-left: 0em; +} + +#axis_options { + margin-left: 0em; +} .inactive{ opacity: 0.5; @@ -780,21 +792,78 @@ footer { display: inline-block; max-width: 16em; margin: 0.3em; + align-self: center; } -.extra-network-cards .nocards{ +#txt2img_extra_view, #img2img_extra_view { + width: auto; +} + +.extra-network-cards .nocards, .extra-network-thumbs .nocards{ margin: 1.25em 0.5em 0.5em 0.5em; } -.extra-network-cards .nocards h1{ +.extra-network-cards .nocards h1, .extra-network-thumbs .nocards h1{ font-size: 1.5em; margin-bottom: 1em; } -.extra-network-cards .nocards li{ +.extra-network-cards .nocards li, .extra-network-thumbs .nocards li{ margin-left: 0.5em; } +.extra-network-thumbs { + display: flex; + flex-flow: row wrap; + gap: 10px; +} + +.extra-network-thumbs .card { + height: 6em; + width: 6em; + cursor: pointer; + background-image: url('./file=html/card-no-preview.png'); + background-size: cover; + background-position: center center; + position: relative; +} + +.extra-network-thumbs .card:hover .additional a { + display: block; +} + +.extra-network-thumbs .actions .additional a { + background-image: url('./file=html/image-update.svg'); + background-repeat: no-repeat; + background-size: cover; + background-position: center center; + position: absolute; + top: 0; + left: 0; + width: 24px; + height: 24px; + display: none; + font-size: 0; + text-align: -9999; +} + +.extra-network-thumbs .actions .name { + position: absolute; + bottom: 0; + font-size: 10px; + padding: 3px; + width: 100%; + overflow: hidden; + white-space: nowrap; + text-overflow: ellipsis; + background: rgba(0,0,0,.5); +} + +.extra-network-thumbs .card:hover .actions .name { + white-space: normal; + word-break: break-all; +} + .extra-network-cards .card{ display: inline-block; margin: 0.5em; diff --git a/webui.py b/webui.py index d235da74..e1565a8d 100644 --- a/webui.py +++ b/webui.py @@ -1,6 +1,5 @@ import os import sys -import threading import time import importlib import signal @@ -8,6 +7,10 @@ import re from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware +from packaging import version + +import logging +logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion @@ -22,7 +25,6 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks import modules.codeformer_model as codeformer -import modules.extras import modules.face_restoration import modules.gfpgan_model as gfpgan import modules.img2img @@ -50,7 +52,32 @@ else: server_name = "0.0.0.0" if cmd_opts.listen else None +def check_versions(): + expected_torch_version = "1.13.1" + + if version.parse(torch.__version__) < version.parse(expected_torch_version): + errors.print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. +To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded. + """.strip()) + + expected_xformers_version = "0.0.16rc425" + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + errors.print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + """.strip()) + + def initialize(): + check_versions() + extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) diff --git a/webui.sh b/webui.sh index 1edf921d..8cdad22d 100755 --- a/webui.sh +++ b/webui.sh @@ -104,6 +104,23 @@ then fi # Check prerequisites +gpu_info=$(lspci 2>/dev/null | grep VGA) +case "$gpu_info" in + *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 + ;; + *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 + printf "\n%s\n" "${delimiter}" + printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half" + printf "\n%s\n" "${delimiter}" + ;; + *) + ;; +esac +if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] +then + export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" +fi + for preq in "${GIT}" "${python_cmd}" do if ! hash "${preq}" &>/dev/null @@ -164,16 +181,6 @@ then else printf "\n%s\n" "${delimiter}" printf "Launching launch.py..." - printf "\n%s\n" "${delimiter}" - gpu_info=$(lspci 2>/dev/null | grep VGA) - if echo "$gpu_info" | grep -q "AMD" - then - if [[ -z "${TORCH_COMMAND}" ]] - then - export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" - fi - HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" - else - exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" - fi + printf "\n%s\n" "${delimiter}" + exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" fi