diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index da1797dc..137e58f7 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -57,6 +57,7 @@ class LoraUpDownModule: def __init__(self): self.up = None self.down = None + self.alpha = None def assign_lora_names_to_compvis_modules(sd_model): @@ -92,6 +93,15 @@ def load_lora(name, filename): keys_failed_to_match.append(key_diffusers) continue + lora_module = lora.modules.get(key, None) + if lora_module is None: + lora_module = LoraUpDownModule() + lora.modules[key] = lora_module + + if lora_key == "alpha": + lora_module.alpha = weight.item() + continue + if type(sd_module) == torch.nn.Linear: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.Conv2d: @@ -104,17 +114,12 @@ def load_lora(name, filename): module.to(device=devices.device, dtype=devices.dtype) - lora_module = lora.modules.get(key, None) - if lora_module is None: - lora_module = LoraUpDownModule() - lora.modules[key] = lora_module - if lora_key == "lora_up.weight": lora_module.up = module elif lora_key == "lora_down.weight": lora_module.down = module else: - assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight' + assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha' if len(keys_failed_to_match) > 0: print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") @@ -161,7 +166,7 @@ def lora_forward(module, input, res): for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) if module is not None: - res = res + module.up(module.down(input)) * lora.multiplier + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) return res diff --git a/html/image-update.svg b/html/image-update.svg new file mode 100644 index 00000000..3abf12df --- /dev/null +++ b/html/image-update.svg @@ -0,0 +1,7 @@ + diff --git a/launch.py b/launch.py index 5afb2956..e7a0b50c 100644 --- a/launch.py +++ b/launch.py @@ -179,7 +179,7 @@ def run_extensions_installers(settings_file): def prepare_environment(): global skip_install - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -187,8 +187,6 @@ def prepare_environment(): clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") - xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') - stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') @@ -210,6 +208,7 @@ def prepare_environment(): sys.argv, _ = extract_arg(sys.argv, '-f') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') + sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch') sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests') sys.argv, skip_install = extract_arg(sys.argv, '--skip-install') @@ -221,7 +220,7 @@ def prepare_environment(): print(f"Python {sys.version}") print(f"Commit hash: {commit}") - if not is_installed("torch") or not is_installed("torchvision"): + if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") if not skip_torch_cuda_test: @@ -239,7 +238,7 @@ def prepare_environment(): if (not is_installed("xformers") or reinstall_xformers) and xformers: if platform.system() == "Windows": if platform.python_version().startswith("3.10"): - run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers") + run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers") else: print("Installation of xformers is not supported in this version of Python.") print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") diff --git a/modules/api/api.py b/modules/api/api.py index 5d60fc0a..b1dd14cc 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -22,6 +22,8 @@ from modules.sd_models import checkpoints_list, find_checkpoint_config from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List +import piexif +import piexif.helper def upscaler_to_index(name: str): try: @@ -56,18 +58,30 @@ def decode_base64_to_image(encoding): def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - # Copy any text-only metadata - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True + if opts.samples_format.lower() == 'png': + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) + + elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + parameters = image.info.get('parameters', None) + exif_bytes = piexif.dump({ + "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } + }) + if opts.samples_format.lower() in ("jpg", "jpeg"): + image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality) + else: + image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality) + + else: + raise HTTPException(status_code=500, detail="Invalid image format") - image.save( - output_bytes, "PNG", pnginfo=(metadata if use_metadata else None) - ) bytes_data = output_bytes.getvalue() + return base64.b64encode(bytes_data) def api_middleware(app: FastAPI): diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 8514fea7..09d8e605 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -67,7 +67,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, pp.image.info["postprocessing"] = infotext if save_output: - images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) + images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) if extras_mode != 2 or show_extras_results: outputs.append(pp.image) diff --git a/modules/shared.py b/modules/shared.py index 63b236c5..d7a18f6a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -432,6 +432,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"), })) +options_templates.update(options_section(('extra_networks', "Extra Networks"), { + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }), +})) + options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index af2b8071..2ddac3d8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -26,6 +26,7 @@ class ExtraNetworksPage: pass def create_html(self, tabname): + view = shared.opts.extra_networks_default_view items_html = '' for item in self.list_items(): @@ -36,7 +37,7 @@ class ExtraNetworksPage: items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) res = f""" -