Merge branch 'master' into textual__inversion

This commit is contained in:
alg-wiki 2022-10-11 03:35:28 +08:00 committed by GitHub
commit f0ab972f85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 1194 additions and 48 deletions

View File

@ -0,0 +1,28 @@
# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
If you have a large change, pay special attention to this paragraph:
> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
**Describe what this pull request is trying to achieve.**
A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
**Additional notes and description of your changes**
More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
**Environment this was tested in**
List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
- OS: [e.g. Windows, Linux]
- Browser [e.g. chrome, safari]
- Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
**Screenshots or videos of your changes**
If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
This is **required** for anything that touches the user interface.

View File

@ -79,6 +79,7 @@ titles = {
"Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
"Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
} }

View File

@ -127,7 +127,7 @@ def prepare_enviroment():
if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"): if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"):
if platform.system() == "Windows": if platform.system() == "Windows":
run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
elif platform.system() == "Linux": elif platform.system() == "Linux":
run_pip("install xformers", "xformers") run_pip("install xformers", "xformers")

View File

@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32")
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16 dtype = torch.float16
dtype_vae = torch.float16
def randn(seed, shape): def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
@ -59,9 +60,12 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device) return torch.randn(shape, device=device)
def autocast(): def autocast(disable=False):
from modules import shared from modules import shared
if disable:
return contextlib.nullcontext()
if dtype == torch.float32 or shared.cmd_opts.precision == "full": if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext() return contextlib.nullcontext()

View File

@ -207,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
# enables the generation of additional tensors with noise that the sampler will use during its processing. # enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101]. # produce the same images as with two batches [100], [101].
if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds: if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else: else:
sampler_noises = None sampler_noises = None
@ -247,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None: if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p) cnt = p.sampler.number_of_needed_noises(p)
if opts.eta_noise_seed_delta > 0:
torch.manual_seed(seed + opts.eta_noise_seed_delta)
for j in range(cnt): for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@ -259,6 +262,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x return x
def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae):
x = model.decode_first_stage(x)
return x
def get_fixed_seed(seed): def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1: if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294)) return int(random.randrange(4294967294))
@ -294,6 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Denoising strength": getattr(p, 'denoising_strength', None), "Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip, "Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
} }
generation_params.update(p.extra_generation_params) generation_params.update(p.extra_generation_params)
@ -398,9 +409,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
# use the image collected previously in sampler loop # use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent samples_ddim = shared.state.current_latent
samples_ddim = samples_ddim.to(devices.dtype) samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim del samples_ddim
@ -533,7 +543,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.scale_latent: if self.scale_latent:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
else: else:
decoded_samples = self.sd_model.decode_first_stage(samples) decoded_samples = decode_first_stage(self.sd_model, samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")

View File

@ -23,7 +23,7 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.model.nonlinearity = silu
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
print("Applying xformers cross attention optimization.") print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
@ -43,10 +43,7 @@ def undo_optimizations():
def get_target_prompt_token_count(token_count): def get_target_prompt_token_count(token_count):
if token_count < 75: return math.ceil(max(token_count, 1) / 75) * 75
return 75
return math.ceil(token_count / 10) * 10
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
@ -127,7 +124,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.token_mults[ident] = mult self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments): def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id id_end = self.wrapped.tokenizer.eos_token_id
if opts.enable_emphasis: if opts.enable_emphasis:
@ -154,7 +150,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
i += 1 i += 1
else: else:
emb_len = int(embedding.vec.shape[0]) emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding)) iteration = len(remade_tokens) // 75
if (len(remade_tokens) + emb_len) // 75 != iteration:
rem = (75 * (iteration + 1) - len(remade_tokens))
remade_tokens += [id_end] * rem
multipliers += [1.0] * rem
iteration += 1
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum())) used_custom_terms.append((embedding.name, embedding.checksum()))
@ -162,10 +164,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
token_count = len(remade_tokens) token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count) prompt_target_length = get_target_prompt_token_count(token_count)
tokens_to_add = prompt_target_length - len(remade_tokens) + 1 tokens_to_add = prompt_target_length - len(remade_tokens)
remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add remade_tokens = remade_tokens + [id_end] * tokens_to_add
multipliers = [1.0] + multipliers + [1.0] * tokens_to_add multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count return remade_tokens, fixes, multipliers, token_count
@ -262,27 +264,53 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text): def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if opts.use_old_emphasis_implementation: if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else: else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.fixes = hijack_fixes
self.hijack.comments += hijack_comments self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0: if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
target_token_count = get_target_prompt_token_count(token_count) + 2 if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76] z = None
position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) i = 0
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] self.hijack.fixes = []
tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) for unfiltered in hijack_fixes:
fixes = []
for fix in unfiltered:
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1: if opts.CLIP_stop_at_last_layers > 1:
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z) z = self.wrapped.transformer.text_model.final_layer_norm(z)
@ -290,7 +318,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z = outputs.last_hidden_state z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean() original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)

View File

@ -13,8 +13,6 @@ from modules import shared
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try: try:
import xformers.ops import xformers.ops
import functorch
xformers._is_functorch_available = True
shared.xformers_available = True shared.xformers_available = True
except Exception: except Exception:
print("Cannot import xformers", file=sys.stderr) print("Cannot import xformers", file=sys.stderr)

View File

@ -149,8 +149,13 @@ def load_model_weights(model, checkpoint_info):
model.half() model.half()
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
vae_file = shared.cmd_opts.vae_path
if os.path.exists(vae_file): if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}") print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location="cpu") vae_ckpt = torch.load(vae_file, map_location="cpu")
@ -158,6 +163,8 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
model.sd_model_hash = sd_model_hash model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info model.sd_checkpoint_info = checkpoint_info

View File

@ -7,7 +7,7 @@ import inspect
import k_diffusion.sampling import k_diffusion.sampling
import ldm.models.diffusion.ddim import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms import ldm.models.diffusion.plms
from modules import prompt_parser from modules import prompt_parser, devices, processing
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None):
def sample_to_image(samples): def sample_to_image(samples):
x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0] x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)

View File

@ -25,6 +25,7 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
@ -65,6 +66,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
@ -259,6 +261,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
})) }))

View File

@ -10,6 +10,7 @@ from tqdm import tqdm
from modules import modelloader from modules import modelloader
from modules.shared import cmd_opts, opts, device from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net from modules.swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
precision_scope = ( precision_scope = (
@ -57,6 +58,22 @@ class UpscalerSwinIR(Upscaler):
filename = path filename = path
if filename is None or not os.path.exists(filename): if filename is None or not os.path.exists(filename):
return None return None
if filename.endswith(".v2.pth"):
model = net2(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="1conv",
)
params = None
else:
model = net( model = net(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
@ -70,9 +87,13 @@ class UpscalerSwinIR(Upscaler):
upsampler="nearest+conv", upsampler="nearest+conv",
resi_connection="3conv", resi_connection="3conv",
) )
params = "params_ema"
pretrained_model = torch.load(filename) pretrained_model = torch.load(filename)
model.load_state_dict(pretrained_model["params_ema"], strict=True) if params is not None:
model.load_state_dict(pretrained_model[params], strict=True)
else:
model.load_state_dict(pretrained_model, strict=True)
if not cmd_opts.no_half: if not cmd_opts.no_half:
model = model.half() model = model.half()
return model return model

File diff suppressed because it is too large Load Diff

View File

@ -202,6 +202,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
return embedding, filename return embedding, filename
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in extns]) tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in extns])
epoch_len = (tr_img_len * num_repeats) + tr_img_len epoch_len = (tr_img_len * num_repeats) + tr_img_len
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)

View File

@ -524,7 +524,7 @@ def create_ui(wrap_gradio_gpu_call):
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row(): with gr.Row():
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
@ -710,7 +710,7 @@ def create_ui(wrap_gradio_gpu_call):
tiling = gr.Checkbox(label='Tiling', value=False) tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row(): with gr.Row():
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
with gr.Group(): with gr.Group():
@ -961,7 +961,7 @@ def create_ui(wrap_gradio_gpu_call):
extras_send_to_inpaint.click( extras_send_to_inpaint.click(
fn=lambda x: image_from_url_text(x), fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery_img2img", _js="extract_image_from_gallery_inpaint",
inputs=[result_images], inputs=[result_images],
outputs=[init_img_with_mask], outputs=[init_img_with_mask],
) )

View File

@ -23,4 +23,3 @@ resize-right
torchdiffeq torchdiffeq
kornia kornia
lark lark
functorch

View File

@ -22,4 +22,3 @@ resize-right==0.0.2
torchdiffeq==0.2.3 torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2 lark==1.1.2
functorch==0.2.1

View File

@ -40,6 +40,22 @@ document.addEventListener("DOMContentLoaded", function() {
mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
}); });
/**
* Add a ctrl+enter as a shortcut to start a generation
*/
document.addEventListener('keydown', function(e) {
var handled = false;
if (e.key !== undefined) {
if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true;
} else if (e.keyCode !== undefined) {
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true;
}
if (handled) {
gradioApp().querySelector("#txt2img_generate").click();
e.preventDefault();
}
})
/** /**
* checks that a UI element is not in another hidden element or tab content * checks that a UI element is not in another hidden element or tab content
*/ */

View File

@ -205,7 +205,10 @@ class Script(scripts.Script):
if not no_fixed_seeds: if not no_fixed_seeds:
modules.processing.fix_seed(p) modules.processing.fix_seed(p)
if not opts.return_grid:
p.batch_size = 1 p.batch_size = 1
CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
def process_axis(opt, vals): def process_axis(opt, vals):

View File

@ -1,3 +1,7 @@
.container {
max-width: 100%;
}
.output-html p {margin: 0 0.5em;} .output-html p {margin: 0 0.5em;}
.row > *, .row > *,
@ -463,3 +467,10 @@ input[type="range"]{
max-width: 32em; max-width: 32em;
padding: 0; padding: 0;
} }
canvas[key="mask"] {
z-index: 12 !important;
filter: invert();
mix-blend-mode: multiply;
pointer-events: none;
}