Merge branch 'dev' into gradio-theme-support

This commit is contained in:
AUTOMATIC1111 2023-04-29 12:45:43 +03:00 committed by GitHub
commit e6cbfcfe5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
53 changed files with 936 additions and 341 deletions

2
.gitignore vendored
View File

@ -32,4 +32,4 @@ notification.mp3
/extensions /extensions
/test/stdout.txt /test/stdout.txt
/test/stderr.txt /test/stderr.txt
/cache.json /cache.json*

View File

@ -13,9 +13,9 @@ A browser interface based on Gradio library for Stable Diffusion.
- Prompt Matrix - Prompt Matrix
- Stable Diffusion Upscale - Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to - Attention, specify parts of text that the model should pay more attention to
- a man in a ((tuxedo)) - will pay more attention to tuxedo - a man in a `((tuxedo))` - will pay more attention to tuxedo
- a man in a (tuxedo:1.21) - alternative syntax - a man in a `(tuxedo:1.21)` - alternative syntax
- select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user) - select text and press `Ctrl+Up` or `Ctrl+Down` to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times - Loopback, run img2img processing multiple times
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters - X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
- Textual Inversion - Textual Inversion
@ -28,7 +28,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- CodeFormer, face restoration tool as an alternative to GFPGAN - CodeFormer, face restoration tool as an alternative to GFPGAN
- RealESRGAN, neural network upscaler - RealESRGAN, neural network upscaler
- ESRGAN, neural network upscaler with a lot of third party models - ESRGAN, neural network upscaler with a lot of third party models
- SwinIR and Swin2SR([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers - SwinIR and Swin2SR ([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
- LDSR, Latent diffusion super resolution upscaling - LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options - Resizing aspect ratio options
- Sampling method selection - Sampling method selection
@ -46,7 +46,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- drag and drop an image/text-parameters to promptbox - drag and drop an image/text-parameters to promptbox
- Read Generation Parameters Button, loads parameters in promptbox to UI - Read Generation Parameters Button, loads parameters in promptbox to UI
- Settings page - Settings page
- Running arbitrary python code from UI (must run with --allow-code to enable) - Running arbitrary python code from UI (must run with `--allow-code` to enable)
- Mouseover hints for most UI elements - Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config - Possible to change defaults/mix/max/step values for UI elements via text config
- Tiling support, a checkbox to create images that can be tiled like textures - Tiling support, a checkbox to create images that can be tiled like textures
@ -69,7 +69,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens) - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts - DeepDanbooru integration, creates danbooru style tags for anime prompts
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args) - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add `--xformers` to commandline args)
- via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI - via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
- Generate forever option - Generate forever option
- Training tab - Training tab
@ -78,11 +78,11 @@ A browser interface based on Gradio library for Stable Diffusion.
- Clip skip - Clip skip
- Hypernetworks - Hypernetworks
- Loras (same as Hypernetworks but more pretty) - Loras (same as Hypernetworks but more pretty)
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt. - A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
- Can select to load a different VAE from settings screen - Can select to load a different VAE from settings screen
- Estimated completion time in progress bar - Estimated completion time in progress bar
- API - API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
@ -91,7 +91,6 @@ A browser interface based on Gradio library for Stable Diffusion.
- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64 - Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
- Now with a license! - Now with a license!
- Reorder elements in the UI from settings screen - Reorder elements in the UI from settings screen
-
## Installation and Running ## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
@ -101,7 +100,7 @@ Alternatively, use online services (like Google Colab):
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services) - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
### Automatic Installation on Windows ### Automatic Installation on Windows
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH" 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win). 2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`. 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user. 4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
@ -159,4 +158,4 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Security advice - RyotaK - Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You) - (You)

View File

@ -4,8 +4,8 @@ channels:
- defaults - defaults
dependencies: dependencies:
- python=3.10 - python=3.10
- pip=22.2.2 - pip=23.0
- cudatoolkit=11.3 - cudatoolkit=11.8
- pytorch=1.12.1 - pytorch=2.0
- torchvision=0.13.1 - torchvision=0.15
- numpy=1.23.1 - numpy=1.23

View File

@ -8,7 +8,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list): def activate(self, p, params_list):
additional = shared.opts.sd_lora additional = shared.opts.sd_lora
if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

View File

@ -2,20 +2,34 @@ import glob
import os import os
import re import re
import torch import torch
from typing import Union
from modules import shared, devices, sd_models, errors from modules import shared, devices, sd_models, errors
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
re_digits = re.compile(r"\d+") re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") re_compiled = {}
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") suffix_conversion = {
"attentions": {},
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
}
}
def convert_diffusers_name_to_compvis(key): def convert_diffusers_name_to_compvis(key, is_sd2):
def match(match_list, regex): def match(match_list, regex_text):
regex = re_compiled.get(regex_text)
if regex is None:
regex = re.compile(regex_text)
re_compiled[regex_text] = regex
r = re.match(regex, key) r = re.match(regex, key)
if not r: if not r:
return False return False
@ -26,16 +40,33 @@ def convert_diffusers_name_to_compvis(key):
m = [] m = []
if match(m, re_unet_down_blocks): if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
if match(m, re_unet_mid_blocks): if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
return f"diffusion_model_middle_block_1_{m[1]}" suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
if match(m, re_unet_up_blocks): if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
if is_sd2:
if 'mlp_fc1' in m[1]:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
elif 'mlp_fc2' in m[1]:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
else:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
if match(m, re_text_block):
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
return key return key
@ -101,15 +132,22 @@ def load_lora(name, filename):
sd = sd_models.read_state_dict(filename) sd = sd_models.read_state_dict(filename)
keys_failed_to_match = [] keys_failed_to_match = {}
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
for key_diffusers, weight in sd.items(): for key_diffusers, weight in sd.items():
fullkey = convert_diffusers_name_to_compvis(key_diffusers) key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
key, lora_key = fullkey.split(".", 1) key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
sd_module = shared.sd_model.lora_layer_mapping.get(key, None) sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
if sd_module is None: if sd_module is None:
keys_failed_to_match.append(key_diffusers) m = re_x_proj.match(key)
if m:
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
if sd_module is None:
keys_failed_to_match[key_diffusers] = key
continue continue
lora_module = lora.modules.get(key, None) lora_module = lora.modules.get(key, None)
@ -123,15 +161,21 @@ def load_lora(name, filename):
if type(sd_module) == torch.nn.Linear: if type(sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d: elif type(sd_module) == torch.nn.Conv2d:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else: else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
with torch.no_grad(): with torch.no_grad():
module.weight.copy_(weight) module.weight.copy_(weight)
module.to(device=devices.device, dtype=devices.dtype) module.to(device=devices.cpu, dtype=devices.dtype)
if lora_key == "lora_up.weight": if lora_key == "lora_up.weight":
lora_module.up = module lora_module.up = module
@ -177,29 +221,120 @@ def load_loras(names, multipliers=None):
loaded_loras.append(lora) loaded_loras.append(lora)
def lora_forward(module, input, res): def lora_calc_updown(lora, module, target):
input = devices.cond_cast_unet(input) with torch.no_grad():
if len(loaded_loras) == 0: up = module.up.weight.to(target.device, dtype=target.dtype)
return res down = module.down.weight.to(target.device, dtype=target.dtype)
lora_layer_name = getattr(module, 'lora_layer_name', None) if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
for lora in loaded_loras: updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
module = lora.modules.get(lora_layer_name, None) else:
if module is not None: updown = up @ down
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return updown
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of Loras to the weights of torch layer self.
If weights already have this particular set of loras applied, does nothing.
If not, restores orginal weights from backup and alters weights according to loras.
"""
lora_layer_name = getattr(self, 'lora_layer_name', None)
if lora_layer_name is None:
return
current_names = getattr(self, "lora_current_names", ())
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
weights_backup = getattr(self, "lora_weights_backup", None)
if weights_backup is None:
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
else:
weights_backup = self.weight.to(devices.cpu, copy=True)
self.lora_weights_backup = weights_backup
if current_names != wanted_names:
if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else: else:
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) self.weight.copy_(weights_backup)
return res for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None and hasattr(self, 'weight'):
self.weight += lora_calc_updown(lora, module, self.weight)
continue
module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.in_proj_weight += updown_qkv
self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
continue
if module is None:
continue
print(f'failed to calculate lora weights for layer {lora_layer_name}')
setattr(self, "lora_current_names", wanted_names)
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
setattr(self, "lora_current_names", ())
setattr(self, "lora_weights_backup", None)
def lora_Linear_forward(self, input): def lora_Linear_forward(self, input):
return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) lora_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input)
def lora_Linear_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
def lora_Conv2d_forward(self, input): def lora_Conv2d_forward(self, input):
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) lora_apply_weights(self)
return torch.nn.Conv2d_forward_before_lora(self, input)
def lora_Conv2d_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
def lora_MultiheadAttention_forward(self, *args, **kwargs):
lora_apply_weights(self)
return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)
return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
def list_available_loras(): def list_available_loras():
@ -212,7 +347,7 @@ def list_available_loras():
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
for filename in sorted(candidates): for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename): if os.path.isdir(filename):
continue continue

View File

@ -9,7 +9,11 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload(): def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
def before_ui(): def before_ui():
@ -20,11 +24,27 @@ def before_ui():
if not hasattr(torch.nn, 'Linear_forward_before_lora'): if not hasattr(torch.nn, 'Linear_forward_before_lora'):
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
torch.nn.Linear.forward = lora.lora_Linear_forward torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload) script_callbacks.on_script_unloaded(unload)
@ -32,7 +52,5 @@ script_callbacks.on_before_ui(before_ui)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
"lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
})) }))

View File

@ -5,11 +5,15 @@ import traceback
import PIL.Image import PIL.Image
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import devices, modelloader from modules import devices, modelloader
from scunet_model_arch import SCUNet as net from scunet_model_arch import SCUNet as net
from modules.shared import opts
from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler): class UpscalerScuNET(modules.upscaler.Upscaler):
@ -42,28 +46,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers.append(scaler_data2) scalers.append(scaler_data2)
self.scalers = scalers self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file): @staticmethod
@torch.no_grad()
def tiled_inference(img, model):
# test the image tile by tile
h, w = img.shape[2:]
tile = opts.SCUNET_tile
tile_overlap = opts.SCUNET_tile_overlap
if tile == 0:
return model(img)
device = devices.get_device_for('scunet')
assert tile % 8 == 0, "tile size should be a multiple of window_size"
sf = 1
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)
return output
def do_upscale(self, img: PIL.Image.Image, selected_file):
torch.cuda.empty_cache() torch.cuda.empty_cache()
model = self.load_model(selected_file) model = self.load_model(selected_file)
if model is None: if model is None:
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
return img return img
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
img = np.array(img) tile = opts.SCUNET_tile
img = img[:, :, ::-1] h, w = img.height, img.width
img = np.moveaxis(img, 2, 0) / 255 np_img = np.array(img)
img = torch.from_numpy(img).float() np_img = np_img[:, :, ::-1] # RGB to BGR
img = img.unsqueeze(0).to(device) np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
with torch.no_grad(): if tile > h or tile > w:
output = model(img) _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() _img[:, :, :h, :w] = torch_img # pad image
output = 255. * np.moveaxis(output, 0, 2) torch_img = _img
output = output.astype(np.uint8)
output = output[:, :, ::-1] torch_output = self.tiled_inference(torch_img, model).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
torch.cuda.empty_cache() torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
output = np_output.transpose((1, 2, 0)) # CHW to HWC
output = output[:, :, ::-1] # BGR to RGB
return PIL.Image.fromarray((output * 255).astype(np.uint8))
def load_model(self, path: str): def load_model(self, path: str):
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
@ -84,4 +138,3 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
model = model.to(device) model = model.to(device)
return model return model

View File

@ -12,7 +12,7 @@ function dimensionChange(e, is_width, is_height){
currentHeight = e.target.value*1.0 currentHeight = e.target.value*1.0
} }
var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
if(!inImg2img){ if(!inImg2img){
return; return;
@ -22,7 +22,7 @@ function dimensionChange(e, is_width, is_height){
var tabIndex = get_tab_index('mode_img2img') var tabIndex = get_tab_index('mode_img2img')
if(tabIndex == 0){ // img2img if(tabIndex == 0){ // img2img
targetElement = gradioApp().querySelector('div[data-testid=image] img'); targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
} else if(tabIndex == 1){ //Sketch } else if(tabIndex == 1){ //Sketch
targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
} else if(tabIndex == 2){ // Inpaint } else if(tabIndex == 2){ // Inpaint
@ -30,7 +30,7 @@ function dimensionChange(e, is_width, is_height){
} else if(tabIndex == 3){ // Inpaint sketch } else if(tabIndex == 3){ // Inpaint sketch
targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
} }
if(targetElement){ if(targetElement){
@ -38,7 +38,7 @@ function dimensionChange(e, is_width, is_height){
if(!arPreviewRect){ if(!arPreviewRect){
arPreviewRect = document.createElement('div') arPreviewRect = document.createElement('div')
arPreviewRect.id = "imageARPreview"; arPreviewRect.id = "imageARPreview";
gradioApp().getRootNode().appendChild(arPreviewRect) gradioApp().appendChild(arPreviewRect)
} }
@ -91,23 +91,26 @@ onUiUpdate(function(){
if(arPreviewRect){ if(arPreviewRect){
arPreviewRect.style.display = 'none'; arPreviewRect.style.display = 'none';
} }
var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) var tabImg2img = gradioApp().querySelector("#tab_img2img");
if(inImg2img){ if (tabImg2img) {
let inputs = gradioApp().querySelectorAll('input'); var inImg2img = tabImg2img.style.display == "block";
inputs.forEach(function(e){ if(inImg2img){
var is_width = e.parentElement.id == "img2img_width" let inputs = gradioApp().querySelectorAll('input');
var is_height = e.parentElement.id == "img2img_height" inputs.forEach(function(e){
var is_width = e.parentElement.id == "img2img_width"
var is_height = e.parentElement.id == "img2img_height"
if((is_width || is_height) && !e.classList.contains('scrollwatch')){ if((is_width || is_height) && !e.classList.contains('scrollwatch')){
e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} ) e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
e.classList.add('scrollwatch') e.classList.add('scrollwatch')
} }
if(is_width){ if(is_width){
currentWidth = e.value*1.0 currentWidth = e.value*1.0
} }
if(is_height){ if(is_height){
currentHeight = e.value*1.0 currentHeight = e.value*1.0
} }
}) })
} }
}
}); });

View File

@ -161,14 +161,6 @@ addContextMenuEventListener = initResponse[2];
appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever) appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever) appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#roll','Roll three',
function(){
let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
setTimeout(function(){rollbutton.click()},100)
setTimeout(function(){rollbutton.click()},200)
setTimeout(function(){rollbutton.click()},300)
}
)
})(); })();
//End example Context Menu Items //End example Context Menu Items

View File

@ -17,7 +17,7 @@ function keyupEditAttention(event){
// Find opening parenthesis around current cursor // Find opening parenthesis around current cursor
const before = text.substring(0, selectionStart); const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN); let beforeParen = before.lastIndexOf(OPEN);
if (beforeParen == -1) return false; if (beforeParen == -1) return false;
let beforeParenClose = before.lastIndexOf(CLOSE); let beforeParenClose = before.lastIndexOf(CLOSE);
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
@ -27,7 +27,7 @@ function keyupEditAttention(event){
// Find closing parenthesis around current cursor // Find closing parenthesis around current cursor
const after = text.substring(selectionStart); const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE); let afterParen = after.indexOf(CLOSE);
if (afterParen == -1) return false; if (afterParen == -1) return false;
let afterParenOpen = after.indexOf(OPEN); let afterParenOpen = after.indexOf(OPEN);
while (afterParenOpen !== -1 && afterParen > afterParenOpen) { while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
afterParen = after.indexOf(CLOSE, afterParen + 1); afterParen = after.indexOf(CLOSE, afterParen + 1);
@ -43,10 +43,28 @@ function keyupEditAttention(event){
target.setSelectionRange(selectionStart, selectionEnd); target.setSelectionRange(selectionStart, selectionEnd);
return true; return true;
} }
function selectCurrentWord(){
if (selectionStart !== selectionEnd) return false;
const delimiters = opts.keyedit_delimiters + " \r\n\t";
// seek backward until to find beggining
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
selectionStart--;
}
// seek forward to find end
while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
selectionEnd++;
}
// If the user hasn't selected anything, let's select their current parenthesis block target.setSelectionRange(selectionStart, selectionEnd);
if(! selectCurrentParenthesisBlock('<', '>')){ return true;
selectCurrentParenthesisBlock('(', ')') }
// If the user hasn't selected anything, let's select their current parenthesis block or word
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
selectCurrentWord();
} }
event.preventDefault(); event.preventDefault();
@ -81,7 +99,13 @@ function keyupEditAttention(event){
weight = parseFloat(weight.toPrecision(12)); weight = parseFloat(weight.toPrecision(12));
if(String(weight).length == 1) weight += ".0" if(String(weight).length == 1) weight += ".0"
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); if (closeCharacter == ')' && weight == 1) {
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
selectionStart--;
selectionEnd--;
} else {
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
}
target.focus(); target.focus();
target.value = text; target.value = text;
@ -93,4 +117,4 @@ function keyupEditAttention(event){
addEventListener('keydown', (event) => { addEventListener('keydown', (event) => {
keyupEditAttention(event); keyupEditAttention(event);
}); });

View File

@ -1,5 +1,5 @@
function extensions_apply(_, _){ function extensions_apply(_, _, disable_all){
var disable = [] var disable = []
var update = [] var update = []
@ -13,10 +13,10 @@ function extensions_apply(_, _){
restart_reload() restart_reload()
return [JSON.stringify(disable), JSON.stringify(update)] return [JSON.stringify(disable), JSON.stringify(update), disable_all]
} }
function extensions_check(){ function extensions_check(_, _){
var disable = [] var disable = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){

View File

@ -16,7 +16,7 @@ onUiUpdate(function(){
let modalObserver = new MutationObserver(function(mutations) { let modalObserver = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) { mutations.forEach(function(mutationRecord) {
let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText let selectedTab = gradioApp().querySelector('#tabs div button')?.innerText
if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img') if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
gradioApp().getElementById(selectedTab+"_generation_info_button").click() gradioApp().getElementById(selectedTab+"_generation_info_button").click()
}); });

View File

@ -21,8 +21,7 @@ titles = {
"\u{1f5d1}\ufe0f": "Clear prompt", "\u{1f5d1}\ufe0f": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4cb}": "Apply selected styles to current prompt",
"\u{1f4d2}": "Paste available values into the field", "\u{1f4d2}": "Paste available values into the field",
"\u{1f3b4}": "Show extra networks", "\u{1f3b4}": "Show/hide extra networks",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",

View File

@ -32,13 +32,7 @@ function negmod(n, m) {
function updateOnBackgroundChange() { function updateOnBackgroundChange() {
const modalImage = gradioApp().getElementById("modalImage") const modalImage = gradioApp().getElementById("modalImage")
if (modalImage && modalImage.offsetParent) { if (modalImage && modalImage.offsetParent) {
let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2") let currentButton = selected_gallery_button();
let currentButton = null
allcurrentButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
currentButton = elem;
}
})
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
modalImage.src = currentButton.children[0].src; modalImage.src = currentButton.children[0].src;
@ -50,22 +44,10 @@ function updateOnBackgroundChange() {
} }
function modalImageSwitch(offset) { function modalImageSwitch(offset) {
var allgalleryButtons = gradioApp().querySelectorAll(".gradio-gallery .thumbnail-item") var galleryButtons = all_gallery_buttons();
var galleryButtons = []
allgalleryButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
galleryButtons.push(elem);
}
})
if (galleryButtons.length > 1) { if (galleryButtons.length > 1) {
var allcurrentButtons = gradioApp().querySelectorAll(".gradio-gallery .thumbnail-item.selected") var currentButton = selected_gallery_button();
var currentButton = null
allcurrentButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
currentButton = elem;
}
})
var result = -1 var result = -1
galleryButtons.forEach(function(v, i) { galleryButtons.forEach(function(v, i) {
@ -269,8 +251,11 @@ document.addEventListener("DOMContentLoaded", function() {
modal.appendChild(modalNext) modal.appendChild(modalNext)
gradioApp().appendChild(modal) try {
gradioApp().appendChild(modal);
} catch (e) {
gradioApp().body.appendChild(modal);
}
document.body.appendChild(modal); document.body.appendChild(modal);

View File

@ -138,7 +138,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
return return
} }
if(elapsedFromStart > 5 && !res.queued && !res.active){ if(elapsedFromStart > 40 && !res.queued && !res.active){
removeProgressBar() removeProgressBar()
return return
} }

View File

@ -7,9 +7,31 @@ function set_theme(theme){
} }
} }
function all_gallery_buttons() {
var allGalleryButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnails > .thumbnail-item.thumbnail-small');
var visibleGalleryButtons = [];
allGalleryButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
visibleGalleryButtons.push(elem);
}
})
return visibleGalleryButtons;
}
function selected_gallery_button() {
var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected');
var visibleCurrentButton = null;
allCurrentButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
visibleCurrentButton = elem;
}
})
return visibleCurrentButton;
}
function selected_gallery_index(){ function selected_gallery_index(){
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item') var buttons = all_gallery_buttons();
var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2') var button = selected_gallery_button();
var result = -1 var result = -1
buttons.forEach(function(v, i){ if(v==button) { result = i } }) buttons.forEach(function(v, i){ if(v==button) { result = i } })
@ -18,14 +40,18 @@ function selected_gallery_index(){
} }
function extract_image_from_gallery(gallery){ function extract_image_from_gallery(gallery){
if(gallery.length == 1){ if (gallery.length == 0){
return [gallery[0]] return [null];
}
if (gallery.length == 1){
return [gallery[0]];
} }
index = selected_gallery_index() index = selected_gallery_index()
if (index < 0 || index >= gallery.length){ if (index < 0 || index >= gallery.length){
return [null] // Use the first image in the gallery as the default
index = 0;
} }
return [gallery[index]]; return [gallery[index]];

View File

@ -121,12 +121,12 @@ def run_python(code, desc=None, errdesc=None):
return run(f'"{python}" -c "{code}"', desc, errdesc) return run(f'"{python}" -c "{code}"', desc, errdesc)
def run_pip(args, desc=None): def run_pip(args, desc=None, live=False):
if skip_install: if skip_install:
return return
index_url_line = f' --index-url {index_url}' if index_url != '' else '' index_url_line = f' --index-url {index_url}' if index_url != '' else ''
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
def check_run_python(code): def check_run_python(code):
@ -206,6 +206,10 @@ def list_extensions(settings_file):
print(e, file=sys.stderr) print(e, file=sys.stderr)
disabled_extensions = set(settings.get('disabled_extensions', [])) disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
if disable_all_extensions != 'none':
return []
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions] return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
@ -221,10 +225,10 @@ def run_extensions_installers(settings_file):
def prepare_environment(): def prepare_environment():
global skip_install global skip_install
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --index-url https://download.pytorch.org/whl/cu118")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425') xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
@ -235,7 +239,7 @@ def prepare_environment():
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
@ -267,7 +271,7 @@ def prepare_environment():
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers: if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
if platform.system() == "Windows": if platform.system() == "Windows":
if platform.python_version().startswith("3.10"): if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers") run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
else: else:
print("Installation of xformers is not supported in this version of Python.") print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
@ -292,7 +296,7 @@ def prepare_environment():
if not os.path.isfile(requirements_file): if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file) requirements_file = os.path.join(script_path, requirements_file)
run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI") run_pip(f"install -r \"{requirements_file}\"", "requirements")
run_extensions_installers(settings_file=args.ui_settings_file) run_extensions_installers(settings_file=args.ui_settings_file)

Binary file not shown.

View File

@ -3,9 +3,9 @@ import io
import time import time
import datetime import datetime
import uvicorn import uvicorn
import gradio as gr
from threading import Lock from threading import Lock
from io import BytesIO from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, Request, Response from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
@ -197,6 +197,9 @@ class Api:
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
def add_api_route(self, path: str, endpoint, **kwargs): def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth: if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
@ -230,7 +233,7 @@ class Api:
script_idx = script_name_to_index(script_name, script_runner.scripts) script_idx = script_name_to_index(script_name, script_runner.scripts)
return script_runner.scripts[script_idx] return script_runner.scripts[script_idx]
def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner): def init_default_script_args(self, script_runner):
#find max idx from the scripts in runner and generate a none array to init script_args #find max idx from the scripts in runner and generate a none array to init script_args
last_arg_index = 1 last_arg_index = 1
for script in script_runner.scripts: for script in script_runner.scripts:
@ -238,13 +241,24 @@ class Api:
last_arg_index = script.args_to last_arg_index = script.args_to
# None everywhere except position 0 to initialize script args # None everywhere except position 0 to initialize script args
script_args = [None]*last_arg_index script_args = [None]*last_arg_index
script_args[0] = 0
# get default values
with gr.Blocks(): # will throw errors calling ui function without this
for script in script_runner.scripts:
if script.ui(script.is_img2img):
ui_default_values = []
for elem in script.ui(script.is_img2img):
ui_default_values.append(elem.value)
script_args[script.args_from:script.args_to] = ui_default_values
return script_args
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
script_args = default_script_args.copy()
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run() # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
if selectable_scripts: if selectable_scripts:
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
script_args[0] = selectable_idx + 1 script_args[0] = selectable_idx + 1
else:
# when [0] = 0 no selectable script to run
script_args[0] = 0
# Now check for always on scripts # Now check for always on scripts
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
@ -265,6 +279,8 @@ class Api:
if not script_runner.scripts: if not script_runner.scripts:
script_runner.initialize_scripts(False) script_runner.initialize_scripts(False)
ui.create_ui() ui.create_ui()
if not self.default_script_arg_txt2img:
self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
@ -280,7 +296,7 @@ class Api:
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None) args.pop('alwayson_scripts', None)
script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner) script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
send_images = args.pop('send_images', True) send_images = args.pop('send_images', True)
args.pop('save_images', None) args.pop('save_images', None)
@ -317,6 +333,8 @@ class Api:
if not script_runner.scripts: if not script_runner.scripts:
script_runner.initialize_scripts(True) script_runner.initialize_scripts(True)
ui.create_ui() ui.create_ui()
if not self.default_script_arg_img2img:
self.default_script_arg_img2img = self.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
@ -334,7 +352,7 @@ class Api:
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None) args.pop('alwayson_scripts', None)
script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner) script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
send_images = args.pop('send_images', True) send_images = args.pop('send_images', True)
args.pop('save_images', None) args.pop('save_images', None)
@ -376,16 +394,11 @@ class Api:
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req) reqDict = setUpscalers(req)
def prepareFiles(file): image_list = reqDict.pop('imageList', [])
file = decode_base64_to_file(file.data, file_path=file.name) image_folder = [decode_base64_to_image(x.data) for x in image_list]
file.orig_name = file.name
return file
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
reqDict.pop('imageList')
with self.queue_lock: with self.queue_lock:
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

View File

@ -4,6 +4,7 @@ from modules.paths_internal import models_path, script_path, data_path, extensio
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program") parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version") parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly") parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")

View File

@ -92,14 +92,18 @@ def cond_cast_float(input):
def randn(seed, shape): def randn(seed, shape):
from modules.shared import opts
torch.manual_seed(seed) torch.manual_seed(seed)
if device.type == 'mps': if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device) return torch.randn(shape, device=device)
def randn_without_seed(shape): def randn_without_seed(shape):
if device.type == 'mps': from modules.shared import opts
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device) return torch.randn(shape, device=device)

View File

@ -5,16 +5,22 @@ import traceback
import time import time
import git import git
from modules import paths, shared from modules import shared
from modules.paths_internal import extensions_dir, extensions_builtin_dir from modules.paths_internal import extensions_dir, extensions_builtin_dir
extensions = [] extensions = []
if not os.path.exists(paths.extensions_dir): if not os.path.exists(extensions_dir):
os.makedirs(paths.extensions_dir) os.makedirs(extensions_dir)
def active(): def active():
return [x for x in extensions if x.enabled] if shared.opts.disable_all_extensions == "all":
return []
elif shared.opts.disable_all_extensions == "extra":
return [x for x in extensions if x.enabled and x.is_builtin]
else:
return [x for x in extensions if x.enabled]
class Extension: class Extension:
@ -26,21 +32,29 @@ class Extension:
self.can_update = False self.can_update = False
self.is_builtin = is_builtin self.is_builtin = is_builtin
self.version = '' self.version = ''
self.remote = None
self.have_info_from_repo = False
def read_info_from_repo(self):
if self.have_info_from_repo:
return
self.have_info_from_repo = True
repo = None repo = None
try: try:
if os.path.exists(os.path.join(path, ".git")): if os.path.exists(os.path.join(self.path, ".git")):
repo = git.Repo(path) repo = git.Repo(self.path)
except Exception: except Exception:
print(f"Error reading github repository info from {path}:", file=sys.stderr) print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
if repo is None or repo.bare: if repo is None or repo.bare:
self.remote = None self.remote = None
else: else:
try: try:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown' self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
head = repo.head.commit head = repo.head.commit
ts = time.asctime(time.gmtime(repo.head.commit.committed_date)) ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
self.version = f'{head.hexsha[:8]} ({ts})' self.version = f'{head.hexsha[:8]} ({ts})'
@ -85,11 +99,16 @@ class Extension:
def list_extensions(): def list_extensions():
extensions.clear() extensions.clear()
if not os.path.isdir(paths.extensions_dir): if not os.path.isdir(extensions_dir):
return return
if shared.opts.disable_all_extensions == "all":
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
extension_paths = [] extension_paths = []
for dirname in [paths.extensions_dir, paths.extensions_builtin_dir]: for dirname in [extensions_dir, extensions_builtin_dir]:
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
return return
@ -98,9 +117,8 @@ def list_extensions():
if not os.path.isdir(path): if not os.path.isdir(path):
continue continue
extension_paths.append((extension_dirname, path, dirname == paths.extensions_builtin_dir)) extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
for dirname, path, is_builtin in extension_paths: for dirname, path, is_builtin in extension_paths:
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
extensions.append(extension) extensions.append(extension)

View File

@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def activate(self, p, params_list): def activate(self, p, params_list):
additional = shared.opts.sd_hypernetwork additional = shared.opts.sd_hypernetwork
if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

View File

@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
restore_old_hires_fix_params(res) restore_old_hires_fix_params(res)
# Missing RNG means the default was set, which is GPU RNG
if "RNG" not in res:
res["RNG"] = "GPU"
return res return res
@ -304,6 +308,7 @@ infotext_to_setting_name_mapping = [
('UniPC skip type', 'uni_pc_skip_type'), ('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'), ('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'), ('UniPC lower order final', 'uni_pc_lower_order_final'),
('RNG', 'randn_source'),
] ]

View File

@ -312,7 +312,7 @@ class Hypernetwork:
def list_hypernetworks(path): def list_hypernetworks(path):
res = {} res = {}
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)): for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True), key=str.lower):
name = os.path.splitext(os.path.basename(filename))[0] name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed. # Prevent a hypothetical "None.pt" from being listed.
if name != "None": if name != "None":

View File

@ -261,9 +261,12 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
if scale > 1.0: if scale > 1.0:
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name] upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}" if len(upscalers) == 0:
upscaler = shared.sd_upscalers[0]
print(f"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback")
else:
upscaler = upscalers[0]
upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path) im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
if im.width != w or im.height != h: if im.width != w or im.height != h:
@ -349,6 +352,7 @@ class FilenameGenerator:
'prompt_no_styles': lambda self: self.prompt_no_style(), 'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False), 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(), 'prompt_words': lambda self: self.prompt_words(),
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
} }
default_time_format = '%Y%m%d%H%M%S' default_time_format = '%Y%m%d%H%M%S'

View File

@ -151,13 +151,14 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
override_settings=override_settings, override_settings=override_settings,
) )
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_img2img
p.script_args = args p.script_args = args
if shared.cmd_opts.enable_console_prompts: if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
p.extra_generation_params["Mask blur"] = mask_blur if mask:
p.extra_generation_params["Mask blur"] = mask_blur
if is_batch: if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"

View File

@ -32,7 +32,7 @@ def download_default_clip_interrogate_categories(content_dir):
category_types = ["artists", "flavors", "mediums", "movements"] category_types = ["artists", "flavors", "mediums", "movements"]
try: try:
os.makedirs(tmpdir) os.makedirs(tmpdir, exist_ok=True)
for category_type in category_types: for category_type in category_types:
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt")) torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
os.rename(tmpdir, content_dir) os.rename(tmpdir, content_dir)
@ -41,7 +41,7 @@ def download_default_clip_interrogate_categories(content_dir):
errors.display(e, "downloading default CLIP interrogate categories") errors.display(e, "downloading default CLIP interrogate categories")
finally: finally:
if os.path.exists(tmpdir): if os.path.exists(tmpdir):
os.remove(tmpdir) os.removedirs(tmpdir)
class InterrogateModels: class InterrogateModels:

View File

@ -55,12 +55,12 @@ def setup_for_low_vram(sd_model, use_medvram):
if hasattr(sd_model.cond_stage_model, 'model'): if hasattr(sd_model.cond_stage_model, 'model'):
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
# remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU. # send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
sd_model.to(devices.device) sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
# register hooks for those the first three models # register hooks for those the first three models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
@ -69,6 +69,8 @@ def setup_for_low_vram(sd_model, use_medvram):
sd_model.first_stage_model.decode = first_stage_model_decode_wrap sd_model.first_stage_model.decode = first_stage_model_decode_wrap
if sd_model.depth_model: if sd_model.depth_model:
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder:
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if hasattr(sd_model.cond_stage_model, 'model'): if hasattr(sd_model.cond_stage_model, 'model'):

View File

@ -18,9 +18,15 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if extras_mode == 1: if extras_mode == 1:
for img in image_folder: for img in image_folder:
image = Image.open(img) if isinstance(img, Image.Image):
image = img
fn = ''
else:
image = Image.open(os.path.abspath(img.name))
fn = os.path.splitext(img.orig_name)[0]
image_data.append(image) image_data.append(image)
image_names.append(os.path.splitext(img.orig_name)[0]) image_names.append(fn)
elif extras_mode == 2: elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
assert input_dir, 'input directory not selected' assert input_dir, 'input directory not selected'

View File

@ -3,6 +3,7 @@ import math
import os import os
import sys import sys
import warnings import warnings
import hashlib
import torch import torch
import numpy as np import numpy as np
@ -78,22 +79,28 @@ def apply_overlay(image, paste_loc, index, overlays):
def txt2img_image_conditioning(sd_model, x, width, height): def txt2img_image_conditioning(sd_model, x, width, height):
if sd_model.model.conditioning_key not in {'hybrid', 'concat'}: if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
# Dummy zero conditioning if we're not using inpainting model.
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
else:
# Dummy zero conditioning if we're not using inpainting or unclip models.
# Still takes up a bit of memory, but no encoder call. # Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
class StableDiffusionProcessing: class StableDiffusionProcessing:
""" """
@ -190,6 +197,14 @@ class StableDiffusionProcessing:
return conditioning_image return conditioning_image
def unclip_image_conditioning(self, source_image):
c_adm = self.sd_model.embedder(source_image)
if self.sd_model.noise_augmentor is not None:
noise_level = 0 # TODO: Allow other noise levels?
c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
c_adm = torch.cat((c_adm, noise_level_emb), 1)
return c_adm
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None): def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
self.is_using_inpainting_conditioning = True self.is_using_inpainting_conditioning = True
@ -241,6 +256,9 @@ class StableDiffusionProcessing:
if self.sampler.conditioning_key in {'hybrid', 'concat'}: if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
# Dummy zero conditioning if we're not using inpainting or depth model. # Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@ -459,6 +477,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip, "Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": (opts.randn_source if opts.randn_source != "GPU" else None)
} }
generation_params.update(p.extra_generation_params) generation_params.update(p.extra_generation_params)
@ -990,6 +1010,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.color_corrections = [] self.color_corrections = []
imgs = [] imgs = []
for img in self.init_images: for img in self.init_images:
# Save init image
if opts.save_init_img:
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
image = images.flatten(img, opts.img2img_background_color) image = images.flatten(img, opts.img2img_background_color)
if crop_region is None and self.resize_mode != 3: if crop_region is None and self.resize_mode != 3:

View File

@ -1,6 +1,5 @@
# this code is adapted from the script contributed by anon from /h/ # this code is adapted from the script contributed by anon from /h/
import io
import pickle import pickle
import collections import collections
import sys import sys
@ -12,11 +11,9 @@ import _codecs
import zipfile import zipfile
import re import re
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args): def encode(*args):
out = _codecs.encode(*args) out = _codecs.encode(*args)
return out return out
@ -27,7 +24,7 @@ class RestrictedUnpickler(pickle.Unpickler):
def persistent_load(self, saved_id): def persistent_load(self, saved_id):
assert saved_id[0] == 'storage' assert saved_id[0] == 'storage'
return TypedStorage() return TypedStorage(_internal=True)
def find_class(self, module, name): def find_class(self, module, name):
if self.extra_handler is not None: if self.extra_handler is not None:

View File

@ -553,3 +553,15 @@ def IOComponent_init(self, *args, **kwargs):
original_IOComponent_init = gr.components.IOComponent.__init__ original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init gr.components.IOComponent.__init__ = IOComponent_init
def BlockContext_init(self, *args, **kwargs):
res = original_BlockContext_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
return res
original_BlockContext_init = gr.blocks.BlockContext.__init__
gr.blocks.BlockContext.__init__ = BlockContext_init

View File

@ -122,7 +122,7 @@ def list_models():
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list: for filename in sorted(model_list, key=str.lower):
checkpoint_info = CheckpointInfo(filename) checkpoint_info = CheckpointInfo(filename)
checkpoint_info.register() checkpoint_info.register()
@ -383,6 +383,14 @@ def repair_config(sd_config):
elif shared.cmd_opts.upcast_sampling: elif shared.cmd_opts.upcast_sampling:
sd_config.model.params.unet_config.params.use_fp16 = True sd_config.model.params.unet_config.params.use_fp16 = True
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
# For UnCLIP-L, override the hardcoded karlo directory
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'

View File

@ -14,6 +14,8 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
@ -65,9 +67,14 @@ def is_using_v_parameterization_for_sd2(state_dict):
def guess_model_config_from_state_dict(sd, filename): def guess_model_config_from_state_dict(sd, filename):
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
return config_depth_model return config_depth_model
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
return config_unclip
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
return config_unopenclip
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
if diffusion_model_input.shape[1] == 9: if diffusion_model_input.shape[1] == 9:

View File

@ -60,3 +60,13 @@ def store_latent(decoded):
class InterruptedException(BaseException): class InterruptedException(BaseException):
pass pass
if opts.randn_source == "CPU":
import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed):
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
torchsde._brownian.brownian_interval._randn = torchsde_randn

View File

@ -70,8 +70,13 @@ class VanillaStableDiffusionSampler:
# Have to unwrap the inpainting conditioning here to perform pre-processing # Have to unwrap the inpainting conditioning here to perform pre-processing
image_conditioning = None image_conditioning = None
uc_image_conditioning = None
if isinstance(cond, dict): if isinstance(cond, dict):
image_conditioning = cond["c_concat"][0] if self.conditioning_key == "crossattn-adm":
image_conditioning = cond["c_adm"]
uc_image_conditioning = unconditional_conditioning["c_adm"]
else:
image_conditioning = cond["c_concat"][0]
cond = cond["c_crossattn"][0] cond = cond["c_crossattn"][0]
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
@ -98,8 +103,12 @@ class VanillaStableDiffusionSampler:
# Wrap the image conditioning back up since the DDIM code can accept the dict directly. # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later. # Note that they need to be lists because it just concatenates them later.
if image_conditioning is not None: if image_conditioning is not None:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} if self.conditioning_key == "crossattn-adm":
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
else:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
return x, ts, cond, unconditional_conditioning return x, ts, cond, unconditional_conditioning
@ -176,8 +185,12 @@ class VanillaStableDiffusionSampler:
# Wrap the conditioning models with additional image conditioning for inpainting model # Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None: if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} if self.conditioning_key == "crossattn-adm":
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
else:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
@ -195,8 +208,12 @@ class VanillaStableDiffusionSampler:
# Wrap the conditioning models with additional image conditioning for inpainting model # Wrap the conditioning models with additional image conditioning for inpainting model
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
if image_conditioning is not None: if image_conditioning is not None:
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} if self.conditioning_key == "crossattn-adm":
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
else:
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])

View File

@ -92,14 +92,21 @@ class CFGDenoiser(torch.nn.Module):
batch_size = len(conds_list) batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)] repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else:
image_uncond = image_cond
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
if not is_edit_model: if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
else: else:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
cfg_denoiser_callback(denoiser_params) cfg_denoiser_callback(denoiser_params)
@ -116,13 +123,13 @@ class CFGDenoiser(torch.nn.Module):
cond_in = torch.cat([tensor, uncond, uncond]) cond_in = torch.cat([tensor, uncond, uncond])
if shared.batch_cond_uncond: if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
else: else:
x_out = torch.zeros_like(x_in) x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size): for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset a = batch_offset
b = a + batch_size b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
else: else:
x_out = torch.zeros_like(x_in) x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
@ -135,9 +142,9 @@ class CFGDenoiser(torch.nn.Module):
else: else:
c_crossattn = torch.cat([tensor[a:b]], uncond) c_crossattn = torch.cat([tensor[a:b]], uncond)
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]}) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
cfg_denoised_callback(denoised_params) cfg_denoised_callback(denoised_params)
@ -183,7 +190,7 @@ class TorchHijack:
if noise.shape == x.shape: if noise.shape == x.shape:
return noise return noise
if x.device.type == 'mps': if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device) return torch.randn_like(x, device=devices.cpu).to(x.device)
else: else:
return torch.randn_like(x) return torch.randn_like(x)

View File

@ -40,6 +40,7 @@ restricted_opts = {
"outdir_grids", "outdir_grids",
"outdir_txt2img_grids", "outdir_txt2img_grids",
"outdir_save", "outdir_save",
"outdir_init_images"
} }
ui_reorder_categories = [ ui_reorder_categories = [
@ -269,6 +270,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"), "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
@ -284,6 +286,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs), "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs), "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs), "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
})) }))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
@ -299,6 +302,8 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration"), {
@ -347,6 +352,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
})) }))
options_templates.update(options_section(('compatibility', "Compatibility"), { options_templates.update(options_section(('compatibility', "Compatibility"), {
@ -377,7 +383,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
})) }))
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface"), {
@ -398,6 +404,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"), "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}), "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
@ -439,7 +446,8 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
})) }))
options_templates.update(options_section((None, "Hidden options"), { options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"), "disabled_extensions": OptionInfo([], "Disable these extensions"),
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
})) }))
@ -675,7 +683,7 @@ mem_mon.start()
def listfiles(dirname): def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)] return [file for file in filenames if os.path.isfile(file)]

View File

@ -70,17 +70,6 @@ def gr_show(visible=True):
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
.wrap.cover-bg .z-20::before { content:"" }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
.meta-text-center { display:none!important; }
"""
# Using constants for these since the variation selector isn't visible. # Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work. # Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f' # 🎲️ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
@ -182,8 +171,8 @@ def create_seed_inputs(target_interface):
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"): with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
seed.style(container=False) seed.style(container=False)
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed') random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed', label='Random seed')
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed') reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed', label='Reuse seed')
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
@ -479,7 +468,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn") res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
if opts.dimensions_and_batch_together: if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"): with gr.Column(elem_id="txt2img_column_batch"):
@ -1215,7 +1204,7 @@ def create_ui():
with gr.Column(elem_id='ti_gallery_container'): with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
ti_progress = gr.HTML(elem_id="ti_progress", value="") ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="") ti_outcome = gr.HTML(elem_id="ti_error", value="")
@ -1566,22 +1555,6 @@ def create_ui():
(train_interface, "Train", "ti"), (train_interface, "Train", "ti"),
] ]
css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"):
if not os.path.isfile(cssfile):
continue
with open(cssfile, "r", encoding="utf8") as file:
css += file.read() + "\n"
if os.path.exists(os.path.join(data_path, "user.css")):
with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
interfaces += script_callbacks.ui_tabs_callback() interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")] interfaces += [(settings_interface, "Settings", "settings")]
@ -1592,7 +1565,7 @@ def create_ui():
for _interface, label, _ifid in interfaces: for _interface, label, _ifid in interfaces:
shared.tab_names.append(label) shared.tab_names.append(label)
with gr.Blocks(css=css, theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"): with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True) component = create_setting_component(k, is_quicksettings=True)
@ -1655,6 +1628,7 @@ def create_ui():
fn=get_settings_values, fn=get_settings_values,
inputs=[], inputs=[],
outputs=[component_dict[k] for k in component_keys], outputs=[component_dict[k] for k in component_keys],
queue=False,
) )
def modelmerger(*args): def modelmerger(*args):
@ -1731,7 +1705,7 @@ def create_ui():
if init_field is not None: if init_field is not None:
init_field(saved_value) init_field(saved_value)
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
apply_field(x, 'visible') apply_field(x, 'visible')
if type(x) == gr.Slider: if type(x) == gr.Slider:
@ -1777,25 +1751,60 @@ def create_ui():
return demo return demo
def reload_javascript(): def webpath(fn):
if fn.startswith(script_path):
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
else:
web_path = os.path.abspath(fn)
return f'file={web_path}?{os.path.getmtime(fn)}'
def javascript_html():
script_js = os.path.join(script_path, "script.js") script_js = os.path.join(script_path, "script.js")
head = f'<script type="text/javascript" src="file={os.path.abspath(script_js)}?{os.path.getmtime(script_js)}"></script>\n' head = f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
inline = f"{localization.localization_js(shared.opts.localization)};" inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None: if cmd_opts.theme is not None:
inline += f"set_theme('{cmd_opts.theme}');" inline += f"set_theme('{cmd_opts.theme}');"
for script in modules.scripts.list_scripts("javascript", ".js"): for script in modules.scripts.list_scripts("javascript", ".js"):
head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n' head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
for script in modules.scripts.list_scripts("javascript", ".mjs"): for script in modules.scripts.list_scripts("javascript", ".mjs"):
head += f'<script type="module" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n' head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
head += f'<script type="text/javascript">{inline}</script>\n' head += f'<script type="text/javascript">{inline}</script>\n'
return head
def css_html():
head = ""
def stylesheet(fn):
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
for cssfile in modules.scripts.list_files_with_name("style.css"):
if not os.path.isfile(cssfile):
continue
head += stylesheet(cssfile)
if os.path.exists(os.path.join(data_path, "user.css")):
head += stylesheet(os.path.join(data_path, "user.css"))
return head
def reload_javascript():
js = javascript_html()
css = css_html()
def template_response(*args, **kwargs): def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs) res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(b'</head>', f'{head}</head>'.encode("utf8")) res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
res.init_headers() res.init_headers()
return res return res

View File

@ -125,7 +125,7 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"): with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"): with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
generation_info = None generation_info = None
with gr.Column(): with gr.Column():
@ -145,8 +145,7 @@ Requested path was: {f}
) )
if tabname != "extras": if tabname != "extras":
with gr.Row(): download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
with gr.Group(): with gr.Group():
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")

View File

@ -21,7 +21,7 @@ def check_access():
assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags" assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
def apply_and_restart(disable_list, update_list): def apply_and_restart(disable_list, update_list, disable_all):
check_access() check_access()
disabled = json.loads(disable_list) disabled = json.loads(disable_list)
@ -43,6 +43,7 @@ def apply_and_restart(disable_list, update_list):
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename) shared.opts.save(shared.config_filename)
shared.state.interrupt() shared.state.interrupt()
@ -63,6 +64,9 @@ def check_updates(id_task, disable_list):
try: try:
ext.check_updates() ext.check_updates()
except FileNotFoundError as e:
if 'FETCH_HEAD' not in str(e):
raise
except Exception: except Exception:
print(f"Error checking updates for {ext.name}:", file=sys.stderr) print(f"Error checking updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
@ -87,6 +91,8 @@ def extension_table():
""" """
for ext in extensions.extensions: for ext in extensions.extensions:
ext.read_info_from_repo()
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>""" remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
if ext.can_update: if ext.can_update:
@ -94,9 +100,13 @@ def extension_table():
else: else:
ext_status = ext.status ext_status = ext.status
style = ""
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
style = ' style="color: var(--primary-400)"'
code += f""" code += f"""
<tr> <tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td> <td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td>{remote}</td> <td>{remote}</td>
<td>{ext.version}</td> <td>{ext.version}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td> <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
@ -119,7 +129,7 @@ def normalize_git_url(url):
return url return url
def install_extension_from_url(dirname, url): def install_extension_from_url(dirname, branch_name, url):
check_access() check_access()
assert url, 'No URL specified' assert url, 'No URL specified'
@ -140,10 +150,17 @@ def install_extension_from_url(dirname, url):
try: try:
shutil.rmtree(tmpdir, True) shutil.rmtree(tmpdir, True)
with git.Repo.clone_from(url, tmpdir) as repo: if branch_name == '':
repo.remote().fetch() # if no branch is specified, use the default branch
for submodule in repo.submodules: with git.Repo.clone_from(url, tmpdir) as repo:
submodule.update() repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
else:
with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
try: try:
os.rename(tmpdir, target_dir) os.rename(tmpdir, target_dir)
except OSError as err: except OSError as err:
@ -289,16 +306,24 @@ def create_ui():
with gr.Row(elem_id="extensions_installed_top"): with gr.Row(elem_id="extensions_installed_top"):
apply = gr.Button(value="Apply and restart UI", variant="primary") apply = gr.Button(value="Apply and restart UI", variant="primary")
check = gr.Button(value="Check for updates") check = gr.Button(value="Check for updates")
extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
info = gr.HTML() html = ""
if shared.opts.disable_all_extensions != "none":
html = """
<span style="color: var(--primary-400);">
"Disable all extensions" was set, change it to "none" to load all extensions again
</span>
"""
info = gr.HTML(html)
extensions_table = gr.HTML(lambda: extension_table()) extensions_table = gr.HTML(lambda: extension_table())
apply.click( apply.click(
fn=apply_and_restart, fn=apply_and_restart,
_js="extensions_apply", _js="extensions_apply",
inputs=[extensions_disabled_list, extensions_update_list], inputs=[extensions_disabled_list, extensions_update_list, extensions_disable_all],
outputs=[], outputs=[],
) )
@ -358,13 +383,14 @@ def create_ui():
with gr.TabItem("Install from URL"): with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository") install_url = gr.Text(label="URL for extension's git repository")
install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
install_button = gr.Button(value="Install", variant="primary") install_button = gr.Button(value="Install", variant="primary")
install_result = gr.HTML(elem_id="extension_install_result") install_result = gr.HTML(elem_id="extension_install_result")
install_button.click( install_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]), fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
inputs=[install_dirname, install_url], inputs=[install_dirname, install_branch, install_url],
outputs=[extensions_table, install_result], outputs=[extensions_table, install_result],
) )

View File

@ -2,8 +2,10 @@ import glob
import os.path import os.path
import urllib.parse import urllib.parse
from pathlib import Path from pathlib import Path
from PIL import PngImagePlugin
from modules import shared from modules import shared
from modules.images import read_info_from_image
import gradio as gr import gradio as gr
import json import json
import html import html
@ -252,10 +254,10 @@ def create_ui(container, button, tabname):
def toggle_visibility(is_visible): def toggle_visibility(is_visible):
is_visible = not is_visible is_visible = not is_visible
return is_visible, gr.update(visible=is_visible) return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
state_visible = gr.State(value=False) state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])
def refresh(): def refresh():
res = [] res = []
@ -290,6 +292,7 @@ def setup_ui(ui, gallery):
img_info = images[index if index >= 0 else 0] img_info = images[index if index >= 0 else 0]
image = image_from_url_text(img_info) image = image_from_url_text(img_info)
geninfo, items = read_info_from_image(image)
is_allowed = False is_allowed = False
for extra_page in ui.stored_extra_pages: for extra_page in ui.stored_extra_pages:
@ -299,7 +302,12 @@ def setup_ui(ui, gallery):
assert is_allowed, f'writing to {filename} is not allowed' assert is_allowed, f'writing to {filename} is not allowed'
image.save(filename) if geninfo:
pnginfo_data = PngImagePlugin.PngInfo()
pnginfo_data.add_text('parameters', geninfo)
image.save(filename, pnginfo=pnginfo_data)
else:
image.save(filename)
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]

View File

@ -13,7 +13,7 @@ def create_ui():
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch: with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir: with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")

View File

@ -1,10 +1,11 @@
astunparse
blendmodes blendmodes
accelerate accelerate
basicsr basicsr
fonts fonts
font-roboto font-roboto
gfpgan gfpgan
gradio==3.23 gradio==3.27
invisible-watermark invisible-watermark
numpy numpy
omegaconf omegaconf

View File

@ -1,10 +1,10 @@
blendmodes==2022 blendmodes==2022
transformers==4.25.1 transformers==4.25.1
accelerate==0.12.0 accelerate==0.18.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.23 gradio==3.27
numpy==1.23.3 numpy==1.23.5
Pillow==9.4.0 Pillow==9.4.0
realesrgan==0.3.0 realesrgan==0.3.0
torch torch
@ -25,6 +25,6 @@ lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.30 GitPython==3.1.30
torchsde==0.2.5 torchsde==0.2.5
safetensors==0.3.0 safetensors==0.3.1
httpcore<=0.15 httpcore<=0.15
fastapi==0.94.0 fastapi==0.94.0

View File

@ -1,9 +1,40 @@
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
import ast
import copy
from modules.processing import Processed from modules.processing import Processed
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
def convertExpr2Expression(expr):
expr.lineno = 0
expr.col_offset = 0
result = ast.Expression(expr.value, lineno=0, col_offset = 0)
return result
def exec_with_return(code, module):
"""
like exec() but can return values
https://stackoverflow.com/a/52361938/5862977
"""
code_ast = ast.parse(code)
init_ast = copy.deepcopy(code_ast)
init_ast.body = code_ast.body[:-1]
last_ast = copy.deepcopy(code_ast)
last_ast.body = code_ast.body[-1:]
exec(compile(init_ast, "<ast>", "exec"), module.__dict__)
if type(last_ast.body[0]) == ast.Expr:
return eval(compile(convertExpr2Expression(last_ast.body[0]), "<ast>", "eval"), module.__dict__)
else:
exec(compile(last_ast, "<ast>", "exec"), module.__dict__)
class Script(scripts.Script): class Script(scripts.Script):
def title(self): def title(self):
@ -13,12 +44,23 @@ class Script(scripts.Script):
return cmd_opts.allow_code return cmd_opts.allow_code
def ui(self, is_img2img): def ui(self, is_img2img):
code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code")) example = """from modules.processing import process_images
return [code] p.width = 768
p.height = 768
p.batch_size = 2
p.steps = 10
return process_images(p)
"""
def run(self, p, code): code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code"))
indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level"))
return [code, indent_level]
def run(self, p, code, indent_level):
assert cmd_opts.allow_code, '--allow-code option must be enabled' assert cmd_opts.allow_code, '--allow-code option must be enabled'
display_result_data = [[], -1, ""] display_result_data = [[], -1, ""]
@ -29,13 +71,20 @@ class Script(scripts.Script):
display_result_data[2] = i display_result_data[2] = i
from types import ModuleType from types import ModuleType
compiled = compile(code, '', 'exec')
module = ModuleType("testmodule") module = ModuleType("testmodule")
module.__dict__.update(globals()) module.__dict__.update(globals())
module.p = p module.p = p
module.display = display module.display = display
exec(compiled, module.__dict__)
indent = " " * indent_level
indented = code.replace('\n', '\n' + indent)
body = f"""def __webuitemp__():
{indent}{indented}
__webuitemp__()"""
result = exec_with_return(body, module)
if isinstance(result, Processed):
return result
return Processed(p, *display_result_data) return Processed(p, *display_result_data)

View File

@ -54,15 +54,12 @@ class Script(scripts.Script):
return strength return strength
progress = loop / (loops - 1) progress = loop / (loops - 1)
match denoising_curve: if denoising_curve == "Aggressive":
case "Aggressive": strength = math.sin((progress) * math.pi * 0.5)
strength = math.sin((progress) * math.pi * 0.5) elif denoising_curve == "Lazy":
strength = 1 - math.cos((progress) * math.pi * 0.5)
case "Lazy": else:
strength = 1 - math.cos((progress) * math.pi * 0.5) strength = progress
case _:
strength = progress
change = (final_denoising_strength - initial_denoising_strength) * strength change = (final_denoising_strength - initial_denoising_strength) * strength
return initial_denoising_strength + change return initial_denoising_strength + change

View File

@ -4,8 +4,8 @@ import numpy as np
from modules import scripts_postprocessing, shared from modules import scripts_postprocessing, shared
import gradio as gr import gradio as gr
from modules.ui_components import FormRow from modules.ui_components import FormRow, ToolButton
from modules.ui import switch_values_symbol
upscale_cache = {} upscale_cache = {}
@ -25,9 +25,12 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
with FormRow(): with FormRow():
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") with gr.Column(elem_id="upscaling_column_size", scale=4):
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with FormRow(): with FormRow():
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
@ -36,6 +39,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])

View File

@ -374,16 +374,19 @@ class Script(scripts.Script):
with gr.Row(): with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True)
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False) fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
with gr.Row(): with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True)
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
with gr.Row(): with gr.Row():
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values")) z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True)
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False) fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
with gr.Row(variant="compact", elem_id="axis_options"): with gr.Row(variant="compact", elem_id="axis_options"):
@ -401,54 +404,74 @@ class Script(scripts.Script):
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button") swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button") swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values): def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown):
return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown
xy_swap_args = [x_type, x_values, y_type, y_values] xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown]
swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args) swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
yz_swap_args = [y_type, y_values, z_type, z_values] yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown]
swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args) swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
xz_swap_args = [x_type, x_values, z_type, z_values] xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args) swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
def fill(x_type): def fill(x_type):
axis = self.current_axis_options[x_type] axis = self.current_axis_options[x_type]
return ", ".join(axis.choices()) if axis.choices else gr.update() return axis.choices() if axis.choices else gr.update()
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown])
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values]) fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown])
def select_axis(x_type): def select_axis(axis_type,axis_values_dropdown):
return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) choices = self.current_axis_options[axis_type].choices
has_choices = choices is not None
current_values = axis_values_dropdown
if has_choices:
choices = choices()
if isinstance(current_values,str):
current_values = current_values.split(",")
current_values = list(filter(lambda x: x in choices, current_values))
return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values)
x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown])
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown])
z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button]) z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
def get_dropdown_update_from_params(axis,params):
val_key = axis + " Values"
vals = params.get(val_key,"")
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
return gr.update(value = valslist)
self.infotext_fields = ( self.infotext_fields = (
(x_type, "X Type"), (x_type, "X Type"),
(x_values, "X Values"), (x_values, "X Values"),
(x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)),
(y_type, "Y Type"), (y_type, "Y Type"),
(y_values, "Y Values"), (y_values, "Y Values"),
(y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)),
(z_type, "Z Type"), (z_type, "Z Type"),
(z_values, "Z Values"), (z_values, "Z Values"),
(z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
) )
return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size] return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size): def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
if not no_fixed_seeds: if not no_fixed_seeds:
modules.processing.fix_seed(p) modules.processing.fix_seed(p)
if not opts.return_grid: if not opts.return_grid:
p.batch_size = 1 p.batch_size = 1
def process_axis(opt, vals): def process_axis(opt, vals, vals_dropdown):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] if opt.choices is not None:
valslist = vals_dropdown
else:
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
if opt.type == int: if opt.type == int:
valslist_ext = [] valslist_ext = []
@ -506,13 +529,19 @@ class Script(scripts.Script):
return valslist return valslist
x_opt = self.current_axis_options[x_type] x_opt = self.current_axis_options[x_type]
xs = process_axis(x_opt, x_values) if x_opt.choices is not None:
x_values = ",".join(x_values_dropdown)
xs = process_axis(x_opt, x_values, x_values_dropdown)
y_opt = self.current_axis_options[y_type] y_opt = self.current_axis_options[y_type]
ys = process_axis(y_opt, y_values) if y_opt.choices is not None:
y_values = ",".join(y_values_dropdown)
ys = process_axis(y_opt, y_values, y_values_dropdown)
z_opt = self.current_axis_options[z_type] z_opt = self.current_axis_options[z_type]
zs = process_axis(z_opt, z_values) if z_opt.choices is not None:
z_values = ",".join(z_values_dropdown)
zs = process_axis(z_opt, z_values, z_values_dropdown)
# this could be moved to common code, but unlikely to be ever triggered anywhere else # this could be moved to common code, but unlikely to be ever triggered anywhere else
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes

View File

@ -7,7 +7,7 @@
--block-background-fill: transparent; --block-background-fill: transparent;
} }
.block.padded{ .block.padded:not(.gradio-accordion) {
padding: 0 !important; padding: 0 !important;
} }
@ -54,10 +54,6 @@ div.compact{
gap: 1em; gap: 1em;
} }
.gradio-dropdown ul.options{
z-index: 3000;
}
.gradio-dropdown label span:not(.has-info), .gradio-dropdown label span:not(.has-info),
.gradio-textbox label span:not(.has-info), .gradio-textbox label span:not(.has-info),
.gradio-number label span:not(.has-info) .gradio-number label span:not(.has-info)
@ -65,11 +61,30 @@ div.compact{
margin-bottom: 0; margin-bottom: 0;
} }
.gradio-dropdown ul.options{
z-index: 3000;
min-width: fit-content;
max-width: inherit;
white-space: nowrap;
}
.gradio-dropdown ul.options li.item {
padding: 0.05em 0;
}
.gradio-dropdown ul.options li.item.selected {
background-color: var(--neutral-100);
}
.dark .gradio-dropdown ul.options li.item.selected {
background-color: var(--neutral-900);
}
.gradio-dropdown div.wrap.wrap.wrap.wrap{ .gradio-dropdown div.wrap.wrap.wrap.wrap{
box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05); box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);
} }
.gradio-dropdown .wrap-inner.wrap-inner.wrap-inner{ .gradio-dropdown:not(.multiselect) .wrap-inner.wrap-inner.wrap-inner{
flex-wrap: unset; flex-wrap: unset;
} }
@ -123,6 +138,18 @@ div.gradio-html.min{
border-radius: 0.5em; border-radius: 0.5em;
} }
.gradio-button.secondary-down{
background: var(--button-secondary-background-fill);
color: var(--button-secondary-text-color);
}
.gradio-button.secondary-down, .gradio-button.secondary-down:hover{
box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;
}
.gradio-button.secondary-down:hover{
background: var(--button-secondary-background-fill-hover);
color: var(--button-secondary-text-color-hover);
}
.checkboxes-row{ .checkboxes-row{
margin-bottom: 0.5em; margin-bottom: 0.5em;
margin-left: 0em; margin-left: 0em;
@ -285,12 +312,23 @@ div.dimensions-tools{
align-content: center; align-content: center;
} }
div#extras_scale_to_tab div.form{
flex-direction: row;
}
#mode_img2img .gradio-image > div.fixed-height, #mode_img2img .gradio-image > div.fixed-height img{ #mode_img2img .gradio-image > div.fixed-height, #mode_img2img .gradio-image > div.fixed-height img{
height: 480px !important; height: 480px !important;
max-height: 480px !important; max-height: 480px !important;
min-height: 480px !important; min-height: 480px !important;
} }
#img2img_sketch, #img2maskimg, #inpaint_sketch {
overflow: overlay !important;
resize: auto;
background: var(--panel-background-fill);
z-index: 5;
}
.image-buttons button{ .image-buttons button{
min-width: auto; min-width: auto;
} }
@ -302,6 +340,7 @@ div.dimensions-tools{
/* settings */ /* settings */
#quicksettings { #quicksettings {
width: fit-content; width: fit-content;
align-items: end;
} }
#quicksettings > div, #quicksettings > fieldset{ #quicksettings > div, #quicksettings > fieldset{
@ -507,6 +546,17 @@ div.dimensions-tools{
background-color: rgba(0, 0, 0, 0.8); background-color: rgba(0, 0, 0, 0.8);
} }
#imageARPreview {
position: absolute;
top: 0px;
left: 0px;
border: 2px solid red;
background: rgba(255, 0, 0, 0.3);
z-index: 900;
pointer-events: none;
display: none;
}
/* context menu (ie for the generate button) */ /* context menu (ie for the generate button) */
#context-menu{ #context-menu{

View File

@ -11,7 +11,7 @@ fi
export install_dir="$HOME" export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1" export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git" export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1 export PYTORCH_ENABLE_MPS_FALLBACK=1

View File

@ -20,6 +20,9 @@ startup_timer = timer.Timer()
import torch import torch
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
startup_timer.record("import torch") startup_timer.record("import torch")
import gradio import gradio
@ -67,11 +70,51 @@ else:
server_name = "0.0.0.0" if cmd_opts.listen else None server_name = "0.0.0.0" if cmd_opts.listen else None
def fix_asyncio_event_loop_policy():
"""
The default `asyncio` event loop policy only automatically creates
event loops in the main threads. Other threads must create event
loops explicitly or `asyncio.get_event_loop` (and therefore
`.IOLoop.current`) will fail. Installing this policy allows event
loops to be created automatically on any thread, matching the
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
"""
import asyncio
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
# "Any thread" and "selector" should be orthogonal, but there's not a clean
# interface for composing policies so pick the right base.
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
else:
_BasePolicy = asyncio.DefaultEventLoopPolicy
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
"""Event loop policy that allows loop creation on any thread.
Usage::
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
"""
def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
except (RuntimeError, AssertionError):
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
# and changed to a RuntimeError in 3.4.3.
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
def check_versions(): def check_versions():
if shared.cmd_opts.skip_version_check: if shared.cmd_opts.skip_version_check:
return return
expected_torch_version = "1.13.1" expected_torch_version = "2.0.0"
if version.parse(torch.__version__) < version.parse(expected_torch_version): if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f""" errors.print_error_explanation(f"""
@ -84,7 +127,7 @@ there are reports of issues with training tab on the latest version.
Use --skip-version-check commandline argument to disable this check. Use --skip-version-check commandline argument to disable this check.
""".strip()) """.strip())
expected_xformers_version = "0.0.16rc425" expected_xformers_version = "0.0.17"
if shared.xformers_available: if shared.xformers_available:
import xformers import xformers
@ -99,6 +142,8 @@ Use --skip-version-check commandline argument to disable this check.
def initialize(): def initialize():
fix_asyncio_event_loop_policy()
check_versions() check_versions()
extensions.list_extensions() extensions.list_extensions()
@ -126,9 +171,6 @@ def initialize():
modules.scripts.load_scripts() modules.scripts.load_scripts()
startup_timer.record("load scripts") startup_timer.record("load scripts")
modelloader.load_upscalers()
startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE") startup_timer.record("refresh VAE")
@ -266,9 +308,6 @@ def webui():
inbrowser=cmd_opts.autolaunch, inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True prevent_thread_lock=True
) )
for dep in shared.demo.dependencies:
dep['show_progress'] = False # disable gradio css animation on component update
# after initial launch, disable --autolaunch for subsequent restarts # after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False cmd_opts.autolaunch = False