From e57b5f7c5560c49fbaf05e6bea326478222cb3e6 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 25 Jan 2023 22:36:14 -0500 Subject: [PATCH 01/46] re_param captures quotes with commas properly and removes unnecessary regex --- modules/generation_parameters_copypaste.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 46e12dc6..13d0874d 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -11,9 +11,8 @@ from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image -re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' +re_param_code = r'\s*([\w ]+):\s*(\"[^\"]*\"|[^,]+)' re_param = re.compile(re_param_code) -re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") type_of_gr_update = type(gr.update()) @@ -243,7 +242,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model done_with_prompt = False *lines, lastline = x.strip().split("\n") - if not re_params.match(lastline): + if not re_param.match(lastline): lines.append(lastline) lastline = '' @@ -262,6 +261,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model res["Negative prompt"] = negative_prompt for k, v in re_param.findall(lastline): + v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v m = re_imagesize.match(v) if m is not None: res[k+"-1"] = m.group(1) From 4d634dc592ffdbd4ebb2f1acfb9a63f5e26e4deb Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 26 Jan 2023 00:18:41 -0500 Subject: [PATCH 02/46] adds components to infotext_fields allows for loading script params --- modules/scripts.py | 14 ++++++++++++++ scripts/xyz_grid.py | 10 ++++++++++ 2 files changed, 24 insertions(+) diff --git a/modules/scripts.py b/modules/scripts.py index 03907a63..eefdfdd4 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -330,6 +330,20 @@ class ScriptRunner: outputs=[script.group for script in self.selectable_scripts] ) + self.script_load_ctr = 0 + def onload_script_visibility(params): + title = params.get('Script', None) + if title: + title_index = self.titles.index(title) + visibility = title_index == self.script_load_ctr + self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles) + return gr.update(visible=visibility) + else: + return gr.update(visible=False) + + self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) ) + self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] ) + return inputs def run(self, p: StableDiffusionProcessing, *args): diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 828c2d12..f3378686 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -383,6 +383,15 @@ class Script(scripts.Script): y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button]) + self.infotext_fields = ( + (x_type, "X Type"), + (x_values, "X Values"), + (y_type, "Y Type"), + (y_values, "Y Values"), + (z_type, "Z Type"), + (z_values, "Z Values"), + ) + return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds] def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds): @@ -541,6 +550,7 @@ class Script(scripts.Script): if grid_infotext[0] is None: pc.extra_generation_params = copy(pc.extra_generation_params) + pc.extra_generation_params['Script'] = self.title() if x_opt.label != 'Nothing': pc.extra_generation_params["X Type"] = x_opt.label From c4b9b07db6272768428fa8efeb7d7a9f22eca0b1 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 26 Jan 2023 09:00:15 -0500 Subject: [PATCH 03/46] Fix embeddings dtype mismatch --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f9652d21..531790f3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -171,7 +171,7 @@ class EmbeddingsWithFixes(torch.nn.Module): vecs = [] for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: - emb = embedding.vec + emb = embedding.vec.to(devices.dtype_unet) if devices.unet_needs_upcast else embedding.vec emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) From a43fafb481feb3ef369d1963412de4e7b320fc34 Mon Sep 17 00:00:00 2001 From: ItsOlegDm Date: Thu, 26 Jan 2023 23:25:48 +0200 Subject: [PATCH 04/46] css fixes --- style.css | 45 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/style.css b/style.css index dd914104..1e90b105 100644 --- a/style.css +++ b/style.css @@ -74,7 +74,12 @@ #txt2img_gallery img, #img2img_gallery img{ object-fit: scale-down; } - +#txt2img_actions_column, #img2img_actions_column { + margin: 0.35rem 0.75rem 0.35rem 0; +} +#script_list { + padding: .625rem .75rem 0 .625rem; +} .justify-center.overflow-x-scroll { justify-content: left; } @@ -126,10 +131,12 @@ #txt2img_actions_column, #img2img_actions_column{ gap: 0; + margin-right: .75rem; } #txt2img_tools, #img2img_tools{ gap: 0.4em; + justify-content: center; } #interrogate_col{ @@ -155,7 +162,9 @@ #txt2img_styles_row > button, #img2img_styles_row > button{ margin: 0; } - +#txt2img_styles_row { + margin-top: 0.3em; +} #txt2img_styles, #img2img_styles{ padding: 0; } @@ -311,11 +320,11 @@ input[type="range"]{ .min-h-\[6rem\] { min-height: unset !important; } .progressDiv{ - position: absolute; + position: relative; height: 20px; - top: -20px; background: #b4c0cc; border-radius: 3px !important; + margin-bottom: -3px; } .dark .progressDiv{ @@ -535,7 +544,7 @@ input[type="range"]{ } #quicksettings { - gap: 0.4em; + width: fit-content; } #quicksettings > div, #quicksettings > fieldset{ @@ -545,6 +554,7 @@ input[type="range"]{ border: none; box-shadow: none; background: none; + margin-right: 10px; } #quicksettings > div > div > div > label > span { @@ -567,7 +577,7 @@ canvas[key="mask"] { right: 0.5em; top: -0.6em; z-index: 400; - width: 8em; + width: 6em; } #quicksettings .gr-box > div > div > input.gr-text-input { top: -1.12em; @@ -665,11 +675,27 @@ canvas[key="mask"] { #quicksettings .gr-button-tool{ margin: 0; + border-color: unset; + background-color: unset; } - +#modelmerger_interp_description>p { + margin: 0!important; + text-align: center; +} +#modelmerger_interp_description { + margin: 0.35rem 0.75rem 1.23rem; +} #img2img_settings > div.gr-form, #txt2img_settings > div.gr-form { padding-top: 0.9em; + padding-bottom: 0.9em; +} +#txt2img_settings { + padding-top: 1.16em; + padding-bottom: 0.9em; +} +#img2img_settings { + padding-bottom: 0.9em; } #img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{ @@ -741,6 +767,8 @@ footer { .dark .gr-compact{ background-color: rgb(31 41 55 / var(--tw-bg-opacity)); + align-items: center; + margin-left: 0; } .gr-compact{ @@ -925,3 +953,6 @@ footer { color: red; } +[id*='_prompt_container'] > div { + margin: 0!important; +} From ada17dbd7c4c68a4e559848d2e6f2a7799722806 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 27 Jan 2023 10:19:43 -0500 Subject: [PATCH 05/46] Refactor conditional casting, fix upscalers --- modules/devices.py | 8 ++++++++ modules/processing.py | 15 ++++++++------- modules/realesrgan_model.py | 2 +- modules/sd_hijack.py | 2 +- modules/sd_hijack_unet.py | 8 +++++++- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..0100e4af 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -83,6 +83,14 @@ dtype_unet = torch.float16 unet_needs_upcast = False +def cond_cast_unet(input): + return input.to(dtype_unet) if unet_needs_upcast else input + + +def cond_cast_float(input): + return input.float() if unet_needs_upcast else input + + def randn(seed, shape): torch.manual_seed(seed) if device.type == 'mps': diff --git a/modules/processing.py b/modules/processing.py index 92894d67..a397702b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,8 +172,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image)) - conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -217,7 +216,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -228,16 +227,18 @@ class StableDiffusionProcessing: return image_conditioning def img2img_image_conditioning(self, source_image, latent_image, image_mask=None): + source_image = devices.cond_cast_float(source_image) + # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely # identify itself with a field common to all models. The conditioning_key is also hybrid. if isinstance(self.sd_model, LatentDepth2ImageDiffusion): - return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) + return self.depth2img_image_conditioning(source_image) if self.sd_model.cond_stage_key == "edit": return self.edit_image_conditioning(source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) @@ -417,7 +418,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x): with devices.autocast(disable=x.dtype == devices.dtype_vae): - x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x) + x = model.decode_first_stage(x) return x @@ -1001,7 +1002,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. - image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None) + image = image.to(shared.device) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 47f70251..aad4a629 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler): scale=info.scale, model_path=info.local_data_path, model=info.model(), - half=not cmd_opts.no_half, + half=not cmd_opts.no_half and not cmd_opts.upcast_sampling, tile=opts.ESRGAN_tile, tile_pad=opts.ESRGAN_tile_overlap, ) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 531790f3..8fc91882 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -171,7 +171,7 @@ class EmbeddingsWithFixes(torch.nn.Module): vecs = [] for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: - emb = embedding.vec.to(devices.dtype_unet) if devices.unet_needs_upcast else embedding.vec + emb = devices.cond_cast_unet(embedding.vec) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index a6ee577c..45cf2b18 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -55,8 +55,14 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module): unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) if version.parse(torch.__version__) <= version.parse("1.13.1"): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) + +first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 +first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) From 02b8b957d763d0fc29551d13d8a2005615e8ce7a Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 28 Jan 2023 00:16:22 -0500 Subject: [PATCH 06/46] Add --no-half-vae to default macOS arguments Apparently the version of PyTorch macOS users are currently at doesn't always handle half precision VAEs correctly. We will probably want to update the default PyTorch version to 2.0 when it comes out which should fix that, and at this point nightly builds of PyTorch 2.0 are going to be recommended for most Mac users. Unfortunately someone has already reported that their M2 Mac doesn't work with the nightly PyTorch 2.0 build currently, so we can add --no-half-vae for now and give users that can install nightly PyTorch 2.0 builds a webui-user.sh configuration that overrides the default. --- webui-macos-env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui-macos-env.sh b/webui-macos-env.sh index fa187dd1..37cac4fb 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -10,7 +10,7 @@ then fi export install_dir="$HOME" -export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate" +export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1" export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git" export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" From f9edd578e9e29d160e6d56038bb368dc49895d64 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 28 Jan 2023 00:20:30 -0500 Subject: [PATCH 07/46] Remove MPS fix no longer needed for PyTorch The torch.narrow fix was required for nightly PyTorch builds for a while to prevent a hard crash, but newer nightly builds don't have this issue. --- modules/devices.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 0100e4af..be542f8f 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -201,6 +201,3 @@ if has_mps(): cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) - orig_narrow = torch.narrow - torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) - From 4aa7f5b5b996c1e3d97640e746f040a23a124860 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 11:11:47 +0300 Subject: [PATCH 08/46] update image parameters regex for #7231 --- modules/generation_parameters_copypaste.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 773c5c0e..1bf35bbb 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -11,7 +11,7 @@ from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image -re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' +re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") From d04e3e921e8ee71442a1f4a1d6e91c05b8238007 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 15:24:29 +0300 Subject: [PATCH 09/46] automatically detect v-parameterization for SD2 checkpoints --- modules/sd_hijack.py | 2 ++ modules/sd_models_config.py | 51 +++++++++++++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f9652d21..03897b2a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -131,6 +131,8 @@ class StableDiffusionModelHijack: m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped m.cond_stage_model = m.cond_stage_model.wrapped + undo_optimizations() + self.apply_circular(False) self.layers = None self.clip = None diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 00217990..91c21700 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -1,7 +1,9 @@ import re import os -from modules import shared, paths +import torch + +from modules import shared, paths, sd_disable_initialization sd_configs_path = shared.sd_configs_path sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") @@ -16,12 +18,51 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml" config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") -re_parametrization_v = re.compile(r'-v\b') + +def is_using_v_parameterization_for_sd2(state_dict): + """ + Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. + """ + + import ldm.modules.diffusionmodules.openaimodel + from modules import devices + + device = devices.cpu + + with sd_disable_initialization.DisableInitialization(): + unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( + use_checkpoint=True, + use_fp16=False, + image_size=32, + in_channels=4, + out_channels=4, + model_channels=320, + attention_resolutions=[4, 2, 1], + num_res_blocks=2, + channel_mult=[1, 2, 4, 4], + num_head_channels=64, + use_spatial_transformer=True, + use_linear_in_transformer=True, + transformer_depth=1, + context_dim=1024, + legacy=False + ) + unet.eval() + + with torch.no_grad(): + unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k} + unet.load_state_dict(unet_sd, strict=True) + unet.to(device=device, dtype=torch.float) + + test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 + x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 + + out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() + + return out < -1 def guess_model_config_from_state_dict(sd, filename): - fn = os.path.basename(filename) - sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) @@ -31,7 +72,7 @@ def guess_model_config_from_state_dict(sd, filename): if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if diffusion_model_input.shape[1] == 9: return config_sd2_inpainting - elif re.search(re_parametrization_v, fn): + elif is_using_v_parameterization_for_sd2(sd): return config_sd2v else: return config_sd2 From f8feeaaedb890de1e36eeb2ad387f0eb3abafd54 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 15:57:56 +0300 Subject: [PATCH 10/46] add progressbar to extension update check; do not check for updates for disabled extensions --- javascript/extensions.js | 20 +++++++++++++++++--- modules/ui_extensions.py | 28 ++++++++++++++++++---------- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/javascript/extensions.js b/javascript/extensions.js index ac6e35b9..c593cd2e 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -1,7 +1,8 @@ function extensions_apply(_, _){ - disable = [] - update = [] + var disable = [] + var update = [] + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ if(x.name.startsWith("enable_") && ! x.checked) disable.push(x.name.substr(7)) @@ -16,11 +17,24 @@ function extensions_apply(_, _){ } function extensions_check(){ + var disable = [] + + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ + if(x.name.startsWith("enable_") && ! x.checked) + disable.push(x.name.substr(7)) + }) + gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ x.innerHTML = "Loading..." }) - return [] + + var id = randomId() + requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){ + + }) + + return [id, JSON.stringify(disable)] } function install_extension_from_index(button, url){ diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 66a41865..37d30e1f 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -13,7 +13,7 @@ import shutil import errno from modules import extensions, shared, paths - +from modules.call_queue import wrap_gradio_gpu_call available_extensions = {"extensions": []} @@ -50,12 +50,17 @@ def apply_and_restart(disable_list, update_list): shared.state.need_restart = True -def check_updates(): +def check_updates(id_task, disable_list): check_access() - for ext in extensions.extensions: - if ext.remote is None: - continue + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled] + shared.state.job_count = len(exts) + + for ext in exts: + shared.state.textinfo = ext.name try: ext.check_updates() @@ -63,7 +68,9 @@ def check_updates(): print(f"Error checking updates for {ext.name}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - return extension_table() + shared.state.nextjob() + + return extension_table(), "" def extension_table(): @@ -273,12 +280,13 @@ def create_ui(): with gr.Tabs(elem_id="tabs_extensions") as tabs: with gr.TabItem("Installed"): - with gr.Row(): + with gr.Row(elem_id="extensions_installed_top"): apply = gr.Button(value="Apply and restart UI", variant="primary") check = gr.Button(value="Check for updates") extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) + info = gr.HTML() extensions_table = gr.HTML(lambda: extension_table()) apply.click( @@ -289,10 +297,10 @@ def create_ui(): ) check.click( - fn=check_updates, + fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]), _js="extensions_check", - inputs=[], - outputs=[extensions_table], + inputs=[info, extensions_disabled_list], + outputs=[extensions_table, info], ) with gr.TabItem("Available"): From 5d14f282c2812888275902be4b552681f942dbfd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 16:23:49 +0300 Subject: [PATCH 11/46] fixed a bug where after switching to a checkpoint with unknown hash, you'd get empty space instead of checkpoint name in UI fixed a bug where if you update a selected checkpoint on disk and then restart the program, a different checkpoint loads, but the name is shown for the the old one. --- modules/sd_models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index b2d48a51..c45ddf83 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -231,12 +231,10 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): - title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - if checkpoint_info.title != title: - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) From 1421e959600e0e9a2435e48373a551237bbab814 Mon Sep 17 00:00:00 2001 From: Thurion Date: Sat, 28 Jan 2023 14:42:24 +0100 Subject: [PATCH 12/46] allow empty mask dir --- modules/img2img.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index fe9447c7..3ecb6146 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -21,8 +21,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): images = shared.listfiles(input_dir) - inpaint_masks = shared.listfiles(inpaint_mask_dir) - is_inpaint_batch = inpaint_mask_dir and len(inpaint_masks) > 0 + is_inpaint_batch = False + if inpaint_mask_dir: + inpaint_masks = shared.listfiles(inpaint_mask_dir) + is_inpaint_batch = len(inpaint_masks) > 0 if is_inpaint_batch: print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") From b7d2af8c7fa48d6eef7517a6fbc63a3507c638d4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 17:18:47 +0300 Subject: [PATCH 13/46] add dropdowns in settings for hypernets and loras --- extensions-builtin/Lora/extra_networks_lora.py | 8 +++++++- extensions-builtin/Lora/scripts/lora_script.py | 3 +++ modules/extra_networks_hypernet.py | 8 +++++++- modules/shared.py | 5 +++-- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 8f2e753e..6be6ef73 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -1,4 +1,4 @@ -from modules import extra_networks +from modules import extra_networks, shared import lora class ExtraNetworkLora(extra_networks.ExtraNetwork): @@ -6,6 +6,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): super().__init__('lora') def activate(self, p, params_list): + additional = shared.opts.sd_lora + + if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: + p.all_prompts = [x + f"" for x in p.all_prompts] + params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) + names = [] multipliers = [] for params in params_list: diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 544b228d..2e860160 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,4 +1,5 @@ import torch +import gradio as gr import lora import extra_networks_lora @@ -31,5 +32,7 @@ script_callbacks.on_before_ui(before_ui) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { + "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"), + })) diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index ff279a1f..d3a4d7ad 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -1,4 +1,4 @@ -from modules import extra_networks +from modules import extra_networks, shared, extra_networks from modules.hypernetworks import hypernetwork @@ -7,6 +7,12 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): super().__init__('hypernet') def activate(self, p, params_list): + additional = shared.opts.sd_hypernetwork + + if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: + p.all_prompts = [x + f"" for x in p.all_prompts] + params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) + names = [] multipliers = [] for params in params_list: diff --git a/modules/shared.py b/modules/shared.py index 474fcc42..eb04e811 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -405,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), - "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), })) @@ -431,7 +430,9 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), })) options_templates.update(options_section(('extra_networks', "Extra Networks"), { - "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }), + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), })) options_templates.update(options_section(('ui', "User interface"), { From 591b68e56c53eed391d08ce008423191c573784d Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 28 Jan 2023 10:04:09 -0500 Subject: [PATCH 14/46] uses autos new regex, checks len of re_param --- modules/generation_parameters_copypaste.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 13d0874d..53f1a865 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -11,7 +11,7 @@ from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image -re_param_code = r'\s*([\w ]+):\s*(\"[^\"]*\"|[^,]+)' +re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") @@ -242,7 +242,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model done_with_prompt = False *lines, lastline = x.strip().split("\n") - if not re_param.match(lastline): + if len(re_param.findall(lastline)) < 3: lines.append(lastline) lastline = '' From f4eeff659e18fc7683f426371394f48b58a00bd3 Mon Sep 17 00:00:00 2001 From: ItsOlegDm Date: Sat, 28 Jan 2023 17:05:08 +0200 Subject: [PATCH 15/46] Removed buttons centering --- style.css | 1 - 1 file changed, 1 deletion(-) diff --git a/style.css b/style.css index 1e90b105..3cbabfd6 100644 --- a/style.css +++ b/style.css @@ -136,7 +136,6 @@ #txt2img_tools, #img2img_tools{ gap: 0.4em; - justify-content: center; } #interrogate_col{ From 1e22f48f4dbef15d8b2ba353b6c3cd68c4d0b42e Mon Sep 17 00:00:00 2001 From: ItsOlegDm Date: Sat, 28 Jan 2023 17:08:38 +0200 Subject: [PATCH 16/46] img2img styled padding fix --- style.css | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/style.css b/style.css index 3cbabfd6..a4bab9be 100644 --- a/style.css +++ b/style.css @@ -156,14 +156,13 @@ #txt2img_styles_row, #img2img_styles_row{ gap: 0.25em; + margin-top: 0.3em; } #txt2img_styles_row > button, #img2img_styles_row > button{ margin: 0; } -#txt2img_styles_row { - margin-top: 0.3em; -} + #txt2img_styles, #img2img_styles{ padding: 0; } From e2c71a4bd41470b9503021db36be2ae65f345d97 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 18:12:53 +0300 Subject: [PATCH 17/46] make prevent the browser from using cached version of scripts when they change --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 9f4cfda1..4e082408 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1692,14 +1692,14 @@ def create_ui(): def reload_javascript(): - head = f'\n' + head = f'\n' inline = f"{localization.localization_js(shared.opts.localization)};" if cmd_opts.theme is not None: inline += f"set_theme('{cmd_opts.theme}');" for script in modules.scripts.list_scripts("javascript", ".js"): - head += f'\n' + head += f'\n' head += f'\n' From 29d2d6a094a1b4028b8d281f069f28bd4cacc944 Mon Sep 17 00:00:00 2001 From: ItsOlegDm Date: Sat, 28 Jan 2023 17:21:59 +0200 Subject: [PATCH 18/46] Train tab fix --- style.css | 1 - 1 file changed, 1 deletion(-) diff --git a/style.css b/style.css index a4bab9be..39312c89 100644 --- a/style.css +++ b/style.css @@ -765,7 +765,6 @@ footer { .dark .gr-compact{ background-color: rgb(31 41 55 / var(--tw-bg-opacity)); - align-items: center; margin-left: 0; } From 1d8e06d542176beade37d2d36cb57460c3c6772f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 22:52:27 +0300 Subject: [PATCH 19/46] add checkpoints tab for extra networks UI --- .../Lora/ui_extra_networks_lora.py | 2 +- javascript/ui.js | 7 ++++ modules/ui.py | 8 ++++ modules/ui_extra_networks.py | 37 ++++++++++++++++-- modules/ui_extra_networks_checkpoints.py | 38 +++++++++++++++++++ modules/ui_extra_networks_hypernets.py | 2 +- .../ui_extra_networks_textual_inversion.py | 2 +- webui.py | 6 ++- 8 files changed, 94 insertions(+), 8 deletions(-) create mode 100644 modules/ui_extra_networks_checkpoints.py diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 54a80d36..c1244b10 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -20,7 +20,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): preview = None for file in previews: if os.path.isfile(file): - preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + preview = self.link_preview(file) break yield { diff --git a/javascript/ui.js b/javascript/ui.js index ba72623c..dd40e62d 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -309,3 +309,10 @@ function updateInput(target){ Object.defineProperty(e, "target", {value: target}) target.dispatchEvent(e); } + + +var desiredCheckpointName = null; +function selectCheckpoint(name){ + desiredCheckpointName = name; + gradioApp().getElementById('change_checkpoint').click() +} diff --git a/modules/ui.py b/modules/ui.py index 4e082408..f1195692 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1560,6 +1560,14 @@ def create_ui(): outputs=[component, text_settings], ) + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) + button_set_checkpoint.click( + fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[component_dict['sd_model_checkpoint'], dummy_component], + outputs=[component_dict['sd_model_checkpoint'], text_settings], + ) + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] def get_settings_values(): diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index c6ff889a..5730c879 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,4 +1,6 @@ import os.path +import urllib.parse +from pathlib import Path from modules import shared import gradio as gr @@ -8,12 +10,31 @@ import html from modules.generation_parameters_copypaste import image_from_url_text extra_pages = [] +allowed_dirs = set() def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" extra_pages.append(page) + allowed_dirs.clear() + allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], []))) + + +def add_pages_to_demo(app): + def fetch_file(filename: str = ""): + from starlette.responses import FileResponse + + if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]): + raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + + if os.path.splitext(filename)[1].lower() != ".png": + raise ValueError(f"File cannot be fetched: {filename}. Only png.") + + # would profit from returning 304 + return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) + + app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) class ExtraNetworksPage: @@ -26,6 +47,9 @@ class ExtraNetworksPage: def refresh(self): pass + def link_preview(self, filename): + return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename)) + def create_html(self, tabname): view = shared.opts.extra_networks_default_view items_html = '' @@ -54,13 +78,17 @@ class ExtraNetworksPage: def create_html_for_item(self, item, tabname): preview = item.get("preview", None) + onclick = item.get("onclick", None) + if onclick is None: + onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + args = { "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', - "prompt": item["prompt"], + "prompt": item.get("prompt", None), "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], - "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"', + "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', } @@ -143,7 +171,7 @@ def path_is_parent(parent_path, child_path): parent_path = os.path.abspath(parent_path) child_path = os.path.abspath(child_path) - return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path]) + return child_path.startswith(parent_path) def setup_ui(ui, gallery): @@ -173,7 +201,8 @@ def setup_ui(ui, gallery): ui.button_save_preview.click( fn=save_preview, - _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", + _js="function(x, y, z){return [selected_gallery_index(), y, z]}", inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], outputs=[*ui.pages] ) + diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py new file mode 100644 index 00000000..c66cb830 --- /dev/null +++ b/modules/ui_extra_networks_checkpoints.py @@ -0,0 +1,38 @@ +import html +import json +import os +import urllib.parse + +from modules import shared, ui_extra_networks, sd_models + + +class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Checkpoints') + + def refresh(self): + shared.refresh_checkpoints() + + def list_items(self): + for name, checkpoint1 in sd_models.checkpoints_list.items(): + checkpoint: sd_models.CheckpointInfo = checkpoint1 + path, ext = os.path.splitext(checkpoint.filename) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = self.link_preview(file) + break + + yield { + "name": checkpoint.model_name, + "filename": path, + "preview": preview, + "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.ckpt_dir, sd_models.model_path] + diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 65d000cf..8c15f8eb 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -19,7 +19,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): preview = None for file in previews: if os.path.isfile(file): - preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + preview = self.link_preview(file) break yield { diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index dbd23d2d..a9d3064b 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -19,7 +19,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): preview = None if os.path.isfile(preview_file): - preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file)) + preview = self.link_preview(preview_file) yield { "name": embedding.name, diff --git a/webui.py b/webui.py index 41f32f5c..0d0b8364 100644 --- a/webui.py +++ b/webui.py @@ -12,7 +12,7 @@ from packaging import version import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) -from modules import import_hook, errors, extra_networks +from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call @@ -119,6 +119,7 @@ def initialize(): ui_extra_networks.intialize() ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) extra_networks.initialize() extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) @@ -227,6 +228,8 @@ def webui(): if launch_api: create_api(app) + ui_extra_networks.add_pages_to_demo(app) + modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo) @@ -254,6 +257,7 @@ def webui(): ui_extra_networks.intialize() ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) extra_networks.initialize() extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) From 0a8515085ef258d4b76fdc000f7ed9d55751d6b8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 23:31:48 +0300 Subject: [PATCH 20/46] make it so that clicking on hypernet/lora card one more time removes the related from the prompt --- javascript/extraNetworks.js | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index c5a9adb3..b5536a34 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -48,10 +48,39 @@ function setupExtraNetworks(){ onUiLoaded(setupExtraNetworks) +var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/; +var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g; + +function tryToRemoveExtraNetworkFromPrompt(textarea, text){ + var m = text.match(re_extranet) + if(! m) return false + + var partToSearch = m[1] + var replaced = false + var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){ + m = found.match(re_extranet); + if(m[1] == partToSearch){ + replaced = true; + return "" + } + return found; + }) + + if(replaced){ + textarea.value = newTextareaText + return true; + } + + return false +} + function cardClicked(tabname, textToAdd, allowNegativePrompt){ var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea") - textarea.value = textarea.value + " " + textToAdd + if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){ + textarea.value = textarea.value + " " + textToAdd + } + updateInput(textarea) } From 09a142a05a6da8bdd4f36678a098c2a573db181a Mon Sep 17 00:00:00 2001 From: glop102 Date: Sat, 28 Jan 2023 19:25:52 -0500 Subject: [PATCH 21/46] Reduce grid rows if larger than number of images available When a set number of grid rows is specified in settings, then it leads to situations where an entire row in the grid is empty. The most noticable example is the processing preview when the row count is set to 2, where it shows the preview just fine but with a black rectangle under it. --- modules/images.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/images.py b/modules/images.py index 0bc3d524..ae3cdaf4 100644 --- a/modules/images.py +++ b/modules/images.py @@ -36,6 +36,8 @@ def image_grid(imgs, batch_size=1, rows=None): else: rows = math.sqrt(len(imgs)) rows = round(rows) + if rows > len(imgs): + rows = len(imgs) cols = math.ceil(len(imgs) / rows) From f6b7768f84a335d351ba8c0d4c34d78e59272339 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 29 Jan 2023 10:20:19 +0300 Subject: [PATCH 22/46] support for searching subdirectory names for extra networks --- extensions-builtin/Lora/ui_extra_networks_lora.py | 1 + html/extra-networks-card.html | 1 + javascript/extraNetworks.js | 2 +- modules/sd_models.py | 1 + modules/ui_extra_networks.py | 11 +++++++++++ modules/ui_extra_networks_checkpoints.py | 6 +++--- modules/ui_extra_networks_hypernets.py | 1 + modules/ui_extra_networks_textual_inversion.py | 1 + 8 files changed, 20 insertions(+), 4 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index c1244b10..22cabcb0 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -27,6 +27,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, + "search_term": self.search_terms_from_path(lora_on_disk.filename), "prompt": json.dumps(f""), "local_preview": path + ".png", } diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index aa9fca87..8a5e2fbd 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -4,6 +4,7 @@ + {name} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index b5536a34..231fafe5 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -16,7 +16,7 @@ function setupExtraNetworksForTab(tabname){ searchTerm = search.value.toLowerCase() gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){ - text = elem.querySelector('.name').textContent.toLowerCase() + text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase() elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : "" }) }); diff --git a/modules/sd_models.py b/modules/sd_models.py index c45ddf83..300387a9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -41,6 +41,7 @@ class CheckpointInfo: name = name[1:] self.name = name + self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 5730c879..29c6e196 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -50,6 +50,16 @@ class ExtraNetworksPage: def link_preview(self, filename): return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename)) + def search_terms_from_path(self, filename, possible_directories=None): + abspath = os.path.abspath(filename) + + for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): + parentdir = os.path.abspath(parentdir) + if abspath.startswith(parentdir): + return abspath[len(parentdir):].replace('\\','/') + + return "" + def create_html(self, tabname): view = shared.opts.extra_networks_default_view items_html = '' @@ -90,6 +100,7 @@ class ExtraNetworksPage: "name": item["name"], "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', + "search_term": item.get("search_term", ""), } return self.card_page.format(**args) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index c66cb830..360579b0 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -14,8 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): shared.refresh_checkpoints() def list_items(self): - for name, checkpoint1 in sd_models.checkpoints_list.items(): - checkpoint: sd_models.CheckpointInfo = checkpoint1 + for name, checkpoint in sd_models.checkpoints_list.items(): path, ext = os.path.splitext(checkpoint.filename) previews = [path + ".png", path + ".preview.png"] @@ -26,9 +25,10 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): break yield { - "name": checkpoint.model_name, + "name": checkpoint.name_for_extra, "filename": path, "preview": preview, + "search_term": self.search_terms_from_path(checkpoint.filename), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", } diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 8c15f8eb..57851088 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -26,6 +26,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, + "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f""), "local_preview": path + ".png", } diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index a9d3064b..bb64eb81 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -25,6 +25,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "name": embedding.name, "filename": embedding.filename, "preview": preview, + "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", } From 659d602dce42608a664642021ea2441da045d189 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sun, 29 Jan 2023 02:32:53 -0500 Subject: [PATCH 23/46] only returns ckpt directories if they are not none --- modules/ui_extra_networks_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index c66cb830..5b471671 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -34,5 +34,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): } def allowed_directories_for_previews(self): - return [shared.cmd_opts.ckpt_dir, sd_models.model_path] + return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] From 8d7382ab24756cdcc37e71406832814f4713c55e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 29 Jan 2023 11:34:58 +0300 Subject: [PATCH 24/46] add buttons for auto-search in subdirectories for extra tabs --- javascript/extraNetworks.js | 9 +++++++++ modules/ui_extra_networks.py | 27 ++++++++++++++++++++++++++- style.css | 6 ++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 231fafe5..17bf2000 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -96,3 +96,12 @@ function saveCardPreview(event, tabname, filename){ event.stopPropagation() event.preventDefault() } + +function extraNetworksSearchButton(tabs_id, event){ + searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea') + button = event.target + text = button.classList.contains("search-all") ? "" : button.textContent.trim() + + searchTextarea.value = text + updateInput(searchTextarea) +} \ No newline at end of file diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 29c6e196..83367968 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,3 +1,4 @@ +import glob import os.path import urllib.parse from pathlib import Path @@ -56,7 +57,7 @@ class ExtraNetworksPage: for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): parentdir = os.path.abspath(parentdir) if abspath.startswith(parentdir): - return abspath[len(parentdir):].replace('\\','/') + return abspath[len(parentdir):].replace('\\', '/') return "" @@ -64,6 +65,27 @@ class ExtraNetworksPage: view = shared.opts.extra_networks_default_view items_html = '' + subdirs = {} + for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: + for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True): + if not os.path.isdir(x): + continue + + subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") + while subdir.startswith("/"): + subdir = subdir[1:] + + subdirs[subdir] = 1 + + if subdirs: + subdirs = {"": 1, **subdirs} + + subdirs_html = "".join([f""" + +""" for subdir in subdirs]) + for item in self.list_items(): items_html += self.create_html_for_item(item, tabname) @@ -72,6 +94,9 @@ class ExtraNetworksPage: items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) res = f""" +
+{subdirs_html} +
{items_html}
diff --git a/style.css b/style.css index 39312c89..05572f66 100644 --- a/style.css +++ b/style.css @@ -807,7 +807,13 @@ footer { margin: 0.3em; } +.extra-network-subdirs{ + padding: 0.2em 0.35em; +} +.extra-network-subdirs button{ + margin: 0 0.15em; +} #txt2img_extra_networks .search, #img2img_extra_networks .search{ display: inline-block; From aa6e55e00140da6d73d3d360a5628c1b1316550d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 29 Jan 2023 11:53:05 +0300 Subject: [PATCH 25/46] do not display the message for TI unless the list of loaded embeddings changed --- modules/textual_inversion/textual_inversion.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6cf00e65..a1a406c2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -112,6 +112,7 @@ class EmbeddingDatabase: self.skipped_embeddings = {} self.expected_shape = -1 self.embedding_dirs = {} + self.previously_displayed_embeddings = () def add_embedding_dir(self, path): self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) @@ -228,9 +229,12 @@ class EmbeddingDatabase: self.load_from_dir(embdir) embdir.update() - print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") - if len(self.skipped_embeddings) > 0: - print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") + displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) + if self.previously_displayed_embeddings != displayed_embeddings: + self.previously_displayed_embeddings = displayed_embeddings + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") + if len(self.skipped_embeddings) > 0: + print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") def find_embedding_at_position(self, tokens, offset): token = tokens[offset] From 00dab8f10defbbda579a1bc89c8d4e972c58a20d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 29 Jan 2023 11:53:24 +0300 Subject: [PATCH 26/46] remove Batch size and Batch pos from textinfo (goodbye) --- modules/processing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index afab6790..2d295932 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -450,8 +450,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Batch size": (None if p.batch_size < 2 else p.batch_size), - "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), From 938578e8a94883aa3c0075cf47eea64f66119541 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 00:25:30 +0300 Subject: [PATCH 27/46] make it so that setting options in pasted infotext (like Clip Skip and ENSD) do not get applied directly and instead are added as temporary overrides --- modules/generation_parameters_copypaste.py | 201 ++++++++++++++------- modules/shared.py | 37 +++- modules/txt2img.py | 6 +- modules/ui.py | 40 +++- modules/ui_common.py | 6 +- 5 files changed, 210 insertions(+), 80 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 3c098e0d..1292fead 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,4 +1,5 @@ import base64 +import html import io import math import os @@ -16,13 +17,23 @@ re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") type_of_gr_update = type(gr.update()) + paste_fields = {} -bind_list = [] +registered_param_bindings = [] + + +class ParamBinding: + def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None): + self.paste_button = paste_button + self.tabname = tabname + self.source_text_component = source_text_component + self.source_image_component = source_image_component + self.source_tabname = source_tabname + self.override_settings_component = override_settings_component def reset(): paste_fields.clear() - bind_list.clear() def quote(text): @@ -74,26 +85,6 @@ def add_paste_fields(tabname, init_img, fields): modules.ui.img2img_paste_fields = fields -def integrate_settings_paste_fields(component_dict): - from modules import ui - - settings_map = { - 'CLIP_stop_at_last_layers': 'Clip skip', - 'inpainting_mask_weight': 'Conditional mask weight', - 'sd_model_checkpoint': 'Model hash', - 'eta_noise_seed_delta': 'ENSD', - 'initial_noise_multiplier': 'Noise multiplier', - } - settings_paste_fields = [ - (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None))) - for k, v in settings_map.items() - ] - - for tabname, info in paste_fields.items(): - if info["fields"] is not None: - info["fields"] += settings_paste_fields - - def create_buttons(tabs_list): buttons = {} for tab in tabs_list: @@ -101,9 +92,60 @@ def create_buttons(tabs_list): return buttons -#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab def bind_buttons(buttons, send_image, send_generate_info): - bind_list.append([buttons, send_image, send_generate_info]) + """old function for backwards compatibility; do not use this, use register_paste_params_button""" + for tabname, button in buttons.items(): + source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None + source_tabname = send_generate_info if isinstance(send_generate_info, str) else None + + register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname)) + + +def register_paste_params_button(binding: ParamBinding): + registered_param_bindings.append(binding) + + +def connect_paste_params_buttons(): + binding: ParamBinding + for binding in registered_param_bindings: + destination_image_component = paste_fields[binding.tabname]["init_img"] + fields = paste_fields[binding.tabname]["fields"] + + destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None) + destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) + + if binding.source_image_component and destination_image_component: + if isinstance(binding.source_image_component, gr.Gallery): + func = send_image_and_dimensions if destination_width_component else image_from_url_text + jsfunc = "extract_image_from_gallery" + else: + func = send_image_and_dimensions if destination_width_component else lambda x: x + jsfunc = None + + binding.paste_button.click( + fn=func, + _js=jsfunc, + inputs=[binding.source_image_component], + outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], + ) + + if binding.source_text_component is not None and fields is not None: + connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component) + + if binding.source_tabname is not None and fields is not None: + paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_button.click( + fn=lambda *x: x, + inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names], + outputs=[field for field, name in fields if name in paste_field_names], + ) + + binding.paste_button.click( + fn=None, + _js=f"switch_to_{binding.tabname}", + inputs=None, + outputs=None, + ) def send_image_and_dimensions(x): @@ -122,49 +164,6 @@ def send_image_and_dimensions(x): return img, w, h -def run_bind(): - for buttons, source_image_component, send_generate_info in bind_list: - for tab in buttons: - button = buttons[tab] - destination_image_component = paste_fields[tab]["init_img"] - fields = paste_fields[tab]["fields"] - - destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None) - destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) - - if source_image_component and destination_image_component: - if isinstance(source_image_component, gr.Gallery): - func = send_image_and_dimensions if destination_width_component else image_from_url_text - jsfunc = "extract_image_from_gallery" - else: - func = send_image_and_dimensions if destination_width_component else lambda x: x - jsfunc = None - - button.click( - fn=func, - _js=jsfunc, - inputs=[source_image_component], - outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], - ) - - if send_generate_info and fields is not None: - if send_generate_info in paste_fields: - paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) - button.click( - fn=lambda *x: x, - inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], - outputs=[field for field, name in fields if name in paste_field_names], - ) - else: - connect_paste(button, fields, send_generate_info) - - button.click( - fn=None, - _js=f"switch_to_{tab}", - inputs=None, - outputs=None, - ) - def find_hypernetwork_key(hypernet_name, hypernet_hash=None): """Determines the config parameter name to use for the hypernet based on the parameters in the infotext. @@ -286,7 +285,47 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model return res -def connect_paste(button, paste_fields, input_comp, jsfunc=None): +settings_map = {} + +infotext_to_setting_name_mapping = [ + ('Clip skip', 'CLIP_stop_at_last_layers', ), + ('Conditional mask weight', 'inpainting_mask_weight'), + ('Model hash', 'sd_model_checkpoint'), + ('ENSD', 'eta_noise_seed_delta'), + ('Noise multiplier', 'initial_noise_multiplier'), +] + + +def create_override_settings_dict(text_pairs): + """creates processing's override_settings parameters from gradio's multiselect + + Example input: + ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337'] + + Example output: + {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337} + """ + + res = {} + + params = {} + for pair in text_pairs: + k, v = pair.split(":", maxsplit=1) + + params[k] = v.strip() + + for param_name, setting_name in infotext_to_setting_name_mapping: + value = params.get(param_name, None) + + if value is None: + continue + + res[setting_name] = shared.opts.cast_value(setting_name, value) + + return res + + +def connect_paste(button, paste_fields, input_comp, override_settings_component, jsfunc=None): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: filename = os.path.join(data_path, "params.txt") @@ -323,6 +362,32 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None): return res + if override_settings_component is not None: + def paste_settings(params): + vals = {} + + for param_name, setting_name in infotext_to_setting_name_mapping: + v = params.get(param_name, None) + if v is None: + continue + + if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap: + continue + + v = shared.opts.cast_value(setting_name, v) + current_value = getattr(shared.opts, setting_name, None) + + if v == current_value: + continue + + vals[param_name] = v + + vals_pairs = [f"{k}: {v}" for k, v in vals.items()] + + return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0) + + paste_fields = paste_fields + [(override_settings_component, paste_settings)] + button.click( fn=paste_func, _js=jsfunc, diff --git a/modules/shared.py b/modules/shared.py index eb04e811..b5370265 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -127,12 +127,13 @@ restricted_opts = { ui_reorder_categories = [ "inpaint", "sampler", + "checkboxes", + "hires_fix", "dimensions", "cfg", "seed", - "checkboxes", - "hires_fix", "batch", + "override_settings", "scripts", ] @@ -346,10 +347,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), { })) options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { - "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"), - "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"), + "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"), + "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"), "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), - "directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs), + "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs), "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), })) @@ -605,11 +606,37 @@ class Options: self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])} + def cast_value(self, key, value): + """casts an arbitrary to the same type as this setting's value with key + Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str) + """ + + if value is None: + return None + + default_value = self.data_labels[key].default + if default_value is None: + default_value = getattr(self, key, None) + if default_value is None: + return None + + expected_type = type(default_value) + if expected_type == bool and value == "False": + value = False + else: + value = expected_type(value) + + return value + + opts = Options() if os.path.exists(config_filename): opts.load(config_filename) +settings_components = None +"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings""" + latent_upscale_default_mode = "Latent" latent_upscale_modes = { "Latent": {"mode": "bilinear", "antialias": False}, diff --git a/modules/txt2img.py b/modules/txt2img.py index e945fd69..16841d0f 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,5 +1,6 @@ import modules.scripts from modules import sd_samplers +from modules.generation_parameters_copypaste import create_override_settings_dict from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, cmd_opts @@ -8,7 +9,9 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args): + override_settings = create_override_settings_dict(override_settings_texts) + p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -38,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_second_pass_steps=hr_second_pass_steps, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y, + override_settings=override_settings, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index f1195692..a7fcdd83 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -380,6 +380,7 @@ def apply_setting(key, value): opts.save(shared.config_filename) return getattr(opts, key) + def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def refresh(): refresh_method() @@ -433,6 +434,18 @@ def get_value_for_setting(key): return gr.update(value=value, **args) +def create_override_settings_dropdown(tabname, row): + dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) + + dropdown.change( + fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), + inputs=[dropdown], + outputs=[dropdown], + ) + + return dropdown + + def create_ui(): import modules.img2img import modules.txt2img @@ -503,6 +516,10 @@ def create_ui(): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + elif category == "override_settings": + with FormRow(elem_id="txt2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('txt2img', row) + elif category == "scripts": with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() @@ -524,7 +541,6 @@ def create_ui(): ) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -555,6 +571,7 @@ def create_ui(): hr_second_pass_steps, hr_resize_x, hr_resize_y, + override_settings, ] + custom_inputs, outputs=[ @@ -615,6 +632,9 @@ def create_ui(): *modules.scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, override_settings_component=override_settings, + )) txt2img_preview_params = [ txt2img_prompt, @@ -762,6 +782,10 @@ def create_ui(): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + elif category == "override_settings": + with FormRow(elem_id="img2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('img2img', row) + elif category == "scripts": with FormGroup(elem_id="img2img_script_container"): custom_inputs = modules.scripts.scripts_img2img.setup_ui() @@ -796,7 +820,6 @@ def create_ui(): ) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -937,6 +960,9 @@ def create_ui(): ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, + )) modules.scripts.scripts_current = None @@ -954,7 +980,11 @@ def create_ui(): html2 = gr.HTML() with gr.Row(): buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) + + for tabname, button in buttons.items(): + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image, + )) image.change( fn=wrap_gradio_call(modules.extras.run_pnginfo), @@ -1363,6 +1393,7 @@ def create_ui(): components = [] component_dict = {} + shared.settings_components = component_dict script_callbacks.ui_settings_callback() opts.reorder() @@ -1529,8 +1560,7 @@ def create_ui(): component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() + parameters_copypaste.connect_paste_params_buttons() with gr.Tabs(elem_id="tabs") as tabs: for interface, label, ifid in interfaces: diff --git a/modules/ui_common.py b/modules/ui_common.py index 9405ac1f..fd047f31 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -198,5 +198,9 @@ Requested path was: {f} html_info = gr.HTML(elem_id=f'html_info_{tabname}') html_log = gr.HTML(elem_id=f'html_log_{tabname}') - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + for paste_tabname, paste_button in buttons.items(): + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( + paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery + )) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log From f91068f426a359942d13bf7ec15b56562141b8d7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 00:37:26 +0300 Subject: [PATCH 28/46] change disable_weights_auto_swap to true by default --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index b5370265..96a2572f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -441,7 +441,7 @@ options_templates.update(options_section(('ui', "User interface"), { "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), - "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), + "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), "font": OptionInfo("", "Font for image grids that have text"), From 399720dac2543fb4cdbe1022ec1a01f2411b81e2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 01:03:31 +0300 Subject: [PATCH 29/46] update prompt token counts after using the paste params button --- javascript/ui.js | 36 +++++++++++++++++----- modules/generation_parameters_copypaste.py | 6 ++-- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index dd40e62d..b7a8268a 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -191,6 +191,28 @@ function confirm_clear_prompt(prompt, negative_prompt) { return [prompt, negative_prompt] } + +promptTokecountUpdateFuncs = {} + +function recalculatePromptTokens(name){ + if(promptTokecountUpdateFuncs[name]){ + promptTokecountUpdateFuncs[name]() + } +} + +function recalculate_prompts_txt2img(){ + recalculatePromptTokens('txt2img_prompt') + recalculatePromptTokens('txt2img_neg_prompt') + return args_to_array(arguments); +} + +function recalculate_prompts_img2img(){ + recalculatePromptTokens('img2img_prompt') + recalculatePromptTokens('img2img_neg_prompt') + return args_to_array(arguments); +} + + opts = {} onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; @@ -232,14 +254,12 @@ onUiUpdate(function(){ return } - prompt.parentElement.insertBefore(counter, prompt) counter.classList.add("token-counter") prompt.parentElement.style.position = "relative" - textarea.addEventListener("input", function(){ - update_token_counter(id_button); - }); + promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); } + textarea.addEventListener("input", promptTokecountUpdateFuncs[id]); } registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button') @@ -273,7 +293,7 @@ onOptionsChanged(function(){ let txt2img_textarea, img2img_textarea = undefined; let wait_time = 800 -let token_timeout; +let token_timeouts = {}; function update_txt2img_tokens(...args) { update_token_counter("txt2img_token_button") @@ -290,9 +310,9 @@ function update_img2img_tokens(...args) { } function update_token_counter(button_id) { - if (token_timeout) - clearTimeout(token_timeout); - token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); + if (token_timeouts[button_id]) + clearTimeout(token_timeouts[button_id]); + token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); } function restart_reload(){ diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 1292fead..2a10524f 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -130,7 +130,7 @@ def connect_paste_params_buttons(): ) if binding.source_text_component is not None and fields is not None: - connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component) + connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname) if binding.source_tabname is not None and fields is not None: paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) @@ -325,7 +325,7 @@ def create_override_settings_dict(text_pairs): return res -def connect_paste(button, paste_fields, input_comp, override_settings_component, jsfunc=None): +def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: filename = os.path.join(data_path, "params.txt") @@ -390,7 +390,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, button.click( fn=paste_func, - _js=jsfunc, + _js=f"recalculate_prompts_{tabname}", inputs=[input_comp], outputs=[x[0] for x in paste_fields], ) From 847ceae1f71ee13e0a397da048d1bb418e8f36c1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 01:41:23 +0300 Subject: [PATCH 30/46] make it possible to search checkpoint by its hash --- modules/ui_extra_networks_checkpoints.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index a6799171..04097a79 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -14,6 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): shared.refresh_checkpoints() def list_items(self): + checkpoint: sd_models.CheckpointInfo for name, checkpoint in sd_models.checkpoints_list.items(): path, ext = os.path.splitext(checkpoint.filename) previews = [path + ".png", path + ".preview.png"] @@ -28,7 +29,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "name": checkpoint.name_for_extra, "filename": path, "preview": preview, - "search_term": self.search_terms_from_path(checkpoint.filename), + "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", } From c81b52ffbd6252842b3473a7aa8eb7ffc88ee7d1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 02:40:26 +0300 Subject: [PATCH 31/46] add override settings component to img2img --- modules/img2img.py | 6 +++++- modules/ui.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 3ecb6146..f813299c 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,6 +7,7 @@ import numpy as np from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops from modules import devices, sd_samplers +from modules.generation_parameters_copypaste import create_override_settings_dict from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state import modules.shared as shared @@ -75,7 +76,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): + override_settings = create_override_settings_dict(override_settings_texts) + is_batch = mode == 5 if mode == 0: # img2img @@ -142,6 +145,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s inpaint_full_res=inpaint_full_res, inpaint_full_res_padding=inpaint_full_res_padding, inpainting_mask_invert=inpainting_mask_invert, + override_settings=override_settings, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index a7fcdd83..f910c582 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -872,7 +872,8 @@ def create_ui(): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, - img2img_batch_inpaint_mask_dir + img2img_batch_inpaint_mask_dir, + override_settings, ] + custom_inputs, outputs=[ img2img_gallery, @@ -961,7 +962,7 @@ def create_ui(): parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, + paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, override_settings_component=override_settings, )) modules.scripts.scripts_current = None From cbd6329488beafe036ea3a3d0cea1a6940105cf9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:12:43 +0300 Subject: [PATCH 32/46] add an environment variable for selecting xformers package --- launch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 370920de..25909469 100644 --- a/launch.py +++ b/launch.py @@ -223,6 +223,7 @@ def prepare_environment(): requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425') gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") @@ -282,7 +283,7 @@ def prepare_environment(): if (not is_installed("xformers") or reinstall_xformers) and xformers: if platform.system() == "Windows": if platform.python_version().startswith("3.10"): - run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers") + run_pip(f"install -U -I --no-deps {xformers_package}", "xformers") else: print("Installation of xformers is not supported in this version of Python.") print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") From 0c7c36a6c6f12da55e04bd79ae068daac8b586a1 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:35:52 +0300 Subject: [PATCH 33/46] Split history sd_samplers.py to sd_samplers_compvis.py --- modules/{sd_samplers.py => sd_samplers_compvis.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => sd_samplers_compvis.py} (100%) diff --git a/modules/sd_samplers.py b/modules/sd_samplers_compvis.py similarity index 100% rename from modules/sd_samplers.py rename to modules/sd_samplers_compvis.py From 9118b086068253c8f25c6277c385606b79c5b036 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:35:52 +0300 Subject: [PATCH 34/46] Split history sd_samplers.py to sd_samplers_compvis.py --- modules/{sd_samplers.py => temp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => temp} (100%) diff --git a/modules/sd_samplers.py b/modules/temp similarity index 100% rename from modules/sd_samplers.py rename to modules/temp From 449531a6c59b030b1cd7c3cba1113c47e0fc1c7d Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:35:53 +0300 Subject: [PATCH 35/46] Split history sd_samplers.py to sd_samplers_compvis.py --- modules/{temp => sd_samplers.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{temp => sd_samplers.py} (100%) diff --git a/modules/temp b/modules/sd_samplers.py similarity index 100% rename from modules/temp rename to modules/sd_samplers.py From 5feae71dd218a3505f14505d71c6b335f9c642ac Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:37:50 +0300 Subject: [PATCH 36/46] Split history sd_samplers.py to sd_samplers_common.py --- modules/{sd_samplers.py => sd_samplers_common.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => sd_samplers_common.py} (100%) diff --git a/modules/sd_samplers.py b/modules/sd_samplers_common.py similarity index 100% rename from modules/sd_samplers.py rename to modules/sd_samplers_common.py From 6e78f6a8961875df11551650b4c5c8bddb6ed9ce Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:37:50 +0300 Subject: [PATCH 37/46] Split history sd_samplers.py to sd_samplers_common.py --- modules/{sd_samplers.py => temp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => temp} (100%) diff --git a/modules/sd_samplers.py b/modules/temp similarity index 100% rename from modules/sd_samplers.py rename to modules/temp From f8fcad502ec97ceb7ca4bf52f0f2efc8b80c0b64 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:37:51 +0300 Subject: [PATCH 38/46] Split history sd_samplers.py to sd_samplers_common.py --- modules/{temp => sd_samplers.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{temp => sd_samplers.py} (100%) diff --git a/modules/temp b/modules/sd_samplers.py similarity index 100% rename from modules/temp rename to modules/sd_samplers.py From aa54a9d41680051b4b28b0655f8d61a2f23600b1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:51:06 +0300 Subject: [PATCH 39/46] split compvis sampler and shared sampler stuff into their own files --- modules/sd_samplers.py | 243 ++--------------- modules/sd_samplers_common.py | 479 +-------------------------------- modules/sd_samplers_compvis.py | 423 +---------------------------- 3 files changed, 28 insertions(+), 1117 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index a7910b56..9a29f1ae 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,22 +1,18 @@ -from collections import namedtuple, deque -import numpy as np -from math import floor +from collections import deque import torch -import tqdm -from PIL import Image import inspect import k_diffusion.sampling -import torchsde._brownian.brownian_interval import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from modules import prompt_parser, devices, processing, images, sd_vae_approx +from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis -from modules.shared import opts, cmd_opts, state +from modules.shared import opts, state import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback +# imports for functions that previously were here and are used by other modules +from modules.sd_samplers_common import samples_to_image_grid, sample_to_image -SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) samplers_k_diffusion = [ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), @@ -39,15 +35,15 @@ samplers_k_diffusion = [ ] samplers_data_k_diffusion = [ - SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) for label, funcname, aliases, options in samplers_k_diffusion if hasattr(k_diffusion.sampling, funcname) ] all_samplers = [ *samplers_data_k_diffusion, - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), + sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), + sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), ] all_samplers_map = {x.name: x for x in all_samplers} @@ -95,202 +91,6 @@ sampler_extra_params = { } -def setup_img2img_steps(p, steps=None): - if opts.img2img_fix_steps or steps is not None: - requested_steps = (steps or p.steps) - steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 - t_enc = requested_steps - 1 - else: - steps = p.steps - t_enc = int(min(p.denoising_strength, 0.999) * steps) - - return steps, t_enc - - -approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2} - - -def single_sample_to_image(sample, approximation=None): - if approximation is None: - approximation = approximation_indexes.get(opts.show_progress_type, 0) - - if approximation == 2: - x_sample = sd_vae_approx.cheap_approximation(sample) - elif approximation == 1: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() - else: - x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] - - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - return Image.fromarray(x_sample) - - -def sample_to_image(samples, index=0, approximation=None): - return single_sample_to_image(samples[index], approximation) - - -def samples_to_image_grid(samples, approximation=None): - return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) - - -def store_latent(decoded): - state.current_latent = decoded - - if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: - if not shared.parallel_processing_allowed: - shared.state.assign_current_image(sample_to_image(decoded)) - - -class InterruptedException(BaseException): - pass - - -class VanillaStableDiffusionSampler: - def __init__(self, constructor, sd_model): - self.sampler = constructor(sd_model) - self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim - self.mask = None - self.nmask = None - self.init_latent = None - self.sampler_noises = None - self.step = 0 - self.stop_at = None - self.eta = None - self.default_eta = 0.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def number_of_needed_noises(self, p): - return 0 - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except InterruptedException: - return self.last_latent - - def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - if state.interrupted or state.skipped: - raise InterruptedException - - if self.stop_at is not None and self.step > self.stop_at: - raise InterruptedException - - # Have to unwrap the inpainting conditioning here to perform pre-processing - image_conditioning = None - if isinstance(cond, dict): - image_conditioning = cond["c_concat"][0] - cond = cond["c_crossattn"][0] - unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) - - assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' - cond = tensor - - # for DDIM, shapes must match, we can't just process cond and uncond independently; - # filling unconditional_conditioning with repeats of the last vector to match length is - # not 100% correct but should work well enough - if unconditional_conditioning.shape[1] < cond.shape[1]: - last_vector = unconditional_conditioning[:, -1:] - last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) - unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) - elif unconditional_conditioning.shape[1] > cond.shape[1]: - unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] - - if self.mask is not None: - img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x_dec = img_orig * self.mask + self.nmask * x_dec - - # Wrap the image conditioning back up since the DDIM code can accept the dict directly. - # Note that they need to be lists because it just concatenates them later. - if image_conditioning is not None: - cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) - - if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * res[1] - else: - self.last_latent = res[1] - - store_latent(self.last_latent) - - self.step += 1 - state.sampling_step = self.step - shared.total_tqdm.update() - - return res - - def initialize(self, p): - self.eta = p.eta if p.eta is not None else opts.eta_ddim - - for fieldname in ['p_sample_ddim', 'p_sample_plms']: - if hasattr(self.sampler, fieldname): - setattr(self.sampler, fieldname, self.p_sample_ddim_hook) - - self.mask = p.mask if hasattr(p, 'mask') else None - self.nmask = p.nmask if hasattr(p, 'nmask') else None - - def adjust_steps_if_invalid(self, p, num_steps): - if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): - valid_step = 999 / (1000 // num_steps) - if valid_step == floor(valid_step): - return int(valid_step) + 1 - - return num_steps - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) - steps = self.adjust_steps_if_invalid(p, steps) - self.initialize(p) - - self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) - x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) - - self.init_latent = x - self.last_latent = x - self.step = 0 - - # Wrap the conditioning models with additional image conditioning for inpainting model - if image_conditioning is not None: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - self.initialize(p) - - self.init_latent = None - self.last_latent = x - self.step = 0 - - steps = self.adjust_steps_if_invalid(p, steps or p.steps) - - # Wrap the conditioning models with additional image conditioning for inpainting model - # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape - if image_conditioning is not None: - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} - - samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) - - return samples_ddim - - class CFGDenoiser(torch.nn.Module): def __init__(self, model): super().__init__() @@ -312,7 +112,7 @@ class CFGDenoiser(torch.nn.Module): def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): if state.interrupted or state.skipped: - raise InterruptedException + raise sd_samplers_common.InterruptedException conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) @@ -354,9 +154,9 @@ class CFGDenoiser(torch.nn.Module): devices.test_for_nans(x_out, "unet") if opts.live_preview_content == "Prompt": - store_latent(x_out[0:uncond.shape[0]]) + sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) elif opts.live_preview_content == "Negative prompt": - store_latent(x_out[-uncond.shape[0]:]) + sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) @@ -395,19 +195,6 @@ class TorchHijack: return torch.randn_like(x) -# MPS fix for randn in torchsde -def torchsde_randn(size, dtype, device, seed): - if device.type == 'mps': - generator = torch.Generator(devices.cpu).manual_seed(int(seed)) - return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device) - else: - generator = torch.Generator(device).manual_seed(int(seed)) - return torch.randn(size, dtype=dtype, device=device, generator=generator) - - -torchsde._brownian.brownian_interval._randn = torchsde_randn - - class KDiffusionSampler: def __init__(self, funcname, sd_model): denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser @@ -430,11 +217,11 @@ class KDiffusionSampler: step = d['i'] latent = d["denoised"] if opts.live_preview_content == "Combined": - store_latent(latent) + sd_samplers_common.store_latent(latent) self.last_latent = latent if self.stop_at is not None and step > self.stop_at: - raise InterruptedException + raise sd_samplers_common.InterruptedException state.sampling_step = step shared.total_tqdm.update() @@ -445,7 +232,7 @@ class KDiffusionSampler: try: return func() - except InterruptedException: + except sd_samplers_common.InterruptedException: return self.last_latent def number_of_needed_noises(self, p): @@ -492,7 +279,7 @@ class KDiffusionSampler: return sigmas def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) sigmas = self.get_sigmas(p, steps) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index a7910b56..5b06e341 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,99 +1,15 @@ from collections import namedtuple, deque import numpy as np -from math import floor import torch -import tqdm from PIL import Image -import inspect -import k_diffusion.sampling import torchsde._brownian.brownian_interval -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms -from modules import prompt_parser, devices, processing, images, sd_vae_approx +from modules import devices, processing, images, sd_vae_approx -from modules.shared import opts, cmd_opts, state +from modules.shared import opts, state import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback - SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) -samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), - ('Euler', 'sample_euler', ['k_euler'], {}), - ('LMS', 'sample_lms', ['k_lms'], {}), - ('Heun', 'sample_heun', ['k_heun'], {}), - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), - ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}), -] - -samplers_data_k_diffusion = [ - SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) - for label, funcname, aliases, options in samplers_k_diffusion - if hasattr(k_diffusion.sampling, funcname) -] - -all_samplers = [ - *samplers_data_k_diffusion, - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), -] -all_samplers_map = {x.name: x for x in all_samplers} - -samplers = [] -samplers_for_img2img = [] -samplers_map = {} - - -def create_sampler(name, model): - if name is not None: - config = all_samplers_map.get(name, None) - else: - config = all_samplers[0] - - assert config is not None, f'bad sampler name: {name}' - - sampler = config.constructor(model) - sampler.config = config - - return sampler - - -def set_samplers(): - global samplers, samplers_for_img2img - - hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS']) - - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] - - samplers_map.clear() - for sampler in all_samplers: - samplers_map[sampler.name.lower()] = sampler.name - for alias in sampler.aliases: - samplers_map[alias.lower()] = sampler.name - - -set_samplers() - -sampler_extra_params = { - 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], -} - def setup_img2img_steps(p, steps=None): if opts.img2img_fix_steps or steps is not None: @@ -147,254 +63,6 @@ class InterruptedException(BaseException): pass -class VanillaStableDiffusionSampler: - def __init__(self, constructor, sd_model): - self.sampler = constructor(sd_model) - self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim - self.mask = None - self.nmask = None - self.init_latent = None - self.sampler_noises = None - self.step = 0 - self.stop_at = None - self.eta = None - self.default_eta = 0.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def number_of_needed_noises(self, p): - return 0 - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except InterruptedException: - return self.last_latent - - def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - if state.interrupted or state.skipped: - raise InterruptedException - - if self.stop_at is not None and self.step > self.stop_at: - raise InterruptedException - - # Have to unwrap the inpainting conditioning here to perform pre-processing - image_conditioning = None - if isinstance(cond, dict): - image_conditioning = cond["c_concat"][0] - cond = cond["c_crossattn"][0] - unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) - - assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' - cond = tensor - - # for DDIM, shapes must match, we can't just process cond and uncond independently; - # filling unconditional_conditioning with repeats of the last vector to match length is - # not 100% correct but should work well enough - if unconditional_conditioning.shape[1] < cond.shape[1]: - last_vector = unconditional_conditioning[:, -1:] - last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) - unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) - elif unconditional_conditioning.shape[1] > cond.shape[1]: - unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] - - if self.mask is not None: - img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x_dec = img_orig * self.mask + self.nmask * x_dec - - # Wrap the image conditioning back up since the DDIM code can accept the dict directly. - # Note that they need to be lists because it just concatenates them later. - if image_conditioning is not None: - cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) - - if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * res[1] - else: - self.last_latent = res[1] - - store_latent(self.last_latent) - - self.step += 1 - state.sampling_step = self.step - shared.total_tqdm.update() - - return res - - def initialize(self, p): - self.eta = p.eta if p.eta is not None else opts.eta_ddim - - for fieldname in ['p_sample_ddim', 'p_sample_plms']: - if hasattr(self.sampler, fieldname): - setattr(self.sampler, fieldname, self.p_sample_ddim_hook) - - self.mask = p.mask if hasattr(p, 'mask') else None - self.nmask = p.nmask if hasattr(p, 'nmask') else None - - def adjust_steps_if_invalid(self, p, num_steps): - if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): - valid_step = 999 / (1000 // num_steps) - if valid_step == floor(valid_step): - return int(valid_step) + 1 - - return num_steps - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) - steps = self.adjust_steps_if_invalid(p, steps) - self.initialize(p) - - self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) - x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) - - self.init_latent = x - self.last_latent = x - self.step = 0 - - # Wrap the conditioning models with additional image conditioning for inpainting model - if image_conditioning is not None: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - self.initialize(p) - - self.init_latent = None - self.last_latent = x - self.step = 0 - - steps = self.adjust_steps_if_invalid(p, steps or p.steps) - - # Wrap the conditioning models with additional image conditioning for inpainting model - # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape - if image_conditioning is not None: - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} - - samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) - - return samples_ddim - - -class CFGDenoiser(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent = None - self.step = 0 - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): - if state.interrupted or state.skipped: - raise InterruptedException - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - - if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - store_latent(x_out[0:uncond.shape[0]]) - elif opts.live_preview_content == "Negative prompt": - store_latent(x_out[-uncond.shape[0]:]) - - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - self.step += 1 - - return denoised - - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - if x.device.type == 'mps': - return torch.randn_like(x, device=devices.cpu).to(x.device) - else: - return torch.randn_like(x) - - # MPS fix for randn in torchsde def torchsde_randn(size, dtype, device, seed): if device.type == 'mps': @@ -407,146 +75,3 @@ def torchsde_randn(size, dtype, device, seed): torchsde._brownian.brownian_interval._randn = torchsde_randn - -class KDiffusionSampler: - def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.default_eta = 1.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.eta = p.eta or opts.eta_ancestral - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - extra_params_kwargs['eta'] = self.eta - - return extra_params_kwargs - - def get_sigmas(self, p, steps): - discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) - if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: - discard_next_to_last_sigma = True - p.extra_generation_params["Discard penultimate sigma"] = True - - steps += 1 if discard_next_to_last_sigma else 0 - - if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) - - if discard_next_to_last_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - - return sigmas - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) - - sigmas = self.get_sigmas(p, steps) - - sigma_sched = sigmas[steps - t_enc - 1:] - xi = x + noise * sigma_sched[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last - extra_params_kwargs['sigma_min'] = sigma_sched[-2] - if 'sigma_max' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_max'] = sigma_sched[0] - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = len(sigma_sched) - 1 - if 'sigma_sched' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_sched'] = sigma_sched - if 'sigmas' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigmas'] = sigma_sched - - self.model_wrap_cfg.init_latent = x - self.last_latent = x - - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): - steps = steps or p.steps - - sigmas = self.get_sigmas(p, steps) - - x = x * sigmas[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() - extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = steps - else: - extra_params_kwargs['sigmas'] = sigmas - - self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index a7910b56..3d35ff72 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -1,150 +1,10 @@ -from collections import namedtuple, deque +import math + import numpy as np -from math import floor import torch -import tqdm -from PIL import Image -import inspect -import k_diffusion.sampling -import torchsde._brownian.brownian_interval -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms -from modules import prompt_parser, devices, processing, images, sd_vae_approx -from modules.shared import opts, cmd_opts, state -import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback - - -SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) - -samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), - ('Euler', 'sample_euler', ['k_euler'], {}), - ('LMS', 'sample_lms', ['k_lms'], {}), - ('Heun', 'sample_heun', ['k_heun'], {}), - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), - ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}), -] - -samplers_data_k_diffusion = [ - SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) - for label, funcname, aliases, options in samplers_k_diffusion - if hasattr(k_diffusion.sampling, funcname) -] - -all_samplers = [ - *samplers_data_k_diffusion, - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), -] -all_samplers_map = {x.name: x for x in all_samplers} - -samplers = [] -samplers_for_img2img = [] -samplers_map = {} - - -def create_sampler(name, model): - if name is not None: - config = all_samplers_map.get(name, None) - else: - config = all_samplers[0] - - assert config is not None, f'bad sampler name: {name}' - - sampler = config.constructor(model) - sampler.config = config - - return sampler - - -def set_samplers(): - global samplers, samplers_for_img2img - - hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS']) - - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] - - samplers_map.clear() - for sampler in all_samplers: - samplers_map[sampler.name.lower()] = sampler.name - for alias in sampler.aliases: - samplers_map[alias.lower()] = sampler.name - - -set_samplers() - -sampler_extra_params = { - 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], -} - - -def setup_img2img_steps(p, steps=None): - if opts.img2img_fix_steps or steps is not None: - requested_steps = (steps or p.steps) - steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 - t_enc = requested_steps - 1 - else: - steps = p.steps - t_enc = int(min(p.denoising_strength, 0.999) * steps) - - return steps, t_enc - - -approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2} - - -def single_sample_to_image(sample, approximation=None): - if approximation is None: - approximation = approximation_indexes.get(opts.show_progress_type, 0) - - if approximation == 2: - x_sample = sd_vae_approx.cheap_approximation(sample) - elif approximation == 1: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() - else: - x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] - - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - return Image.fromarray(x_sample) - - -def sample_to_image(samples, index=0, approximation=None): - return single_sample_to_image(samples[index], approximation) - - -def samples_to_image_grid(samples, approximation=None): - return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) - - -def store_latent(decoded): - state.current_latent = decoded - - if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: - if not shared.parallel_processing_allowed: - shared.state.assign_current_image(sample_to_image(decoded)) - - -class InterruptedException(BaseException): - pass +from modules.shared import state +from modules import sd_samplers_common, prompt_parser, shared class VanillaStableDiffusionSampler: @@ -174,15 +34,15 @@ class VanillaStableDiffusionSampler: try: return func() - except InterruptedException: + except sd_samplers_common.InterruptedException: return self.last_latent def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): if state.interrupted or state.skipped: - raise InterruptedException + raise sd_samplers_common.InterruptedException if self.stop_at is not None and self.step > self.stop_at: - raise InterruptedException + raise sd_samplers_common.InterruptedException # Have to unwrap the inpainting conditioning here to perform pre-processing image_conditioning = None @@ -224,7 +84,7 @@ class VanillaStableDiffusionSampler: else: self.last_latent = res[1] - store_latent(self.last_latent) + sd_samplers_common.store_latent(self.last_latent) self.step += 1 state.sampling_step = self.step @@ -233,7 +93,7 @@ class VanillaStableDiffusionSampler: return res def initialize(self, p): - self.eta = p.eta if p.eta is not None else opts.eta_ddim + self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim for fieldname in ['p_sample_ddim', 'p_sample_plms']: if hasattr(self.sampler, fieldname): @@ -245,13 +105,13 @@ class VanillaStableDiffusionSampler: def adjust_steps_if_invalid(self, p, num_steps): if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): valid_step = 999 / (1000 // num_steps) - if valid_step == floor(valid_step): + if valid_step == math.floor(valid_step): return int(valid_step) + 1 return num_steps def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) steps = self.adjust_steps_if_invalid(p, steps) self.initialize(p) @@ -289,264 +149,3 @@ class VanillaStableDiffusionSampler: samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) return samples_ddim - - -class CFGDenoiser(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent = None - self.step = 0 - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): - if state.interrupted or state.skipped: - raise InterruptedException - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - - if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - store_latent(x_out[0:uncond.shape[0]]) - elif opts.live_preview_content == "Negative prompt": - store_latent(x_out[-uncond.shape[0]:]) - - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - self.step += 1 - - return denoised - - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - if x.device.type == 'mps': - return torch.randn_like(x, device=devices.cpu).to(x.device) - else: - return torch.randn_like(x) - - -# MPS fix for randn in torchsde -def torchsde_randn(size, dtype, device, seed): - if device.type == 'mps': - generator = torch.Generator(devices.cpu).manual_seed(int(seed)) - return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device) - else: - generator = torch.Generator(device).manual_seed(int(seed)) - return torch.randn(size, dtype=dtype, device=device, generator=generator) - - -torchsde._brownian.brownian_interval._randn = torchsde_randn - - -class KDiffusionSampler: - def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.default_eta = 1.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.eta = p.eta or opts.eta_ancestral - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - extra_params_kwargs['eta'] = self.eta - - return extra_params_kwargs - - def get_sigmas(self, p, steps): - discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) - if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: - discard_next_to_last_sigma = True - p.extra_generation_params["Discard penultimate sigma"] = True - - steps += 1 if discard_next_to_last_sigma else 0 - - if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) - - if discard_next_to_last_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - - return sigmas - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) - - sigmas = self.get_sigmas(p, steps) - - sigma_sched = sigmas[steps - t_enc - 1:] - xi = x + noise * sigma_sched[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last - extra_params_kwargs['sigma_min'] = sigma_sched[-2] - if 'sigma_max' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_max'] = sigma_sched[0] - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = len(sigma_sched) - 1 - if 'sigma_sched' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_sched'] = sigma_sched - if 'sigmas' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigmas'] = sigma_sched - - self.model_wrap_cfg.init_latent = x - self.last_latent = x - - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): - steps = steps or p.steps - - sigmas = self.get_sigmas(p, steps) - - x = x * sigmas[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() - extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = steps - else: - extra_params_kwargs['sigmas'] = sigmas - - self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - From f4d0538bf2f6430b145bb26a294b7f82b50f031a Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:51:23 +0300 Subject: [PATCH 40/46] Split history sd_samplers.py to sd_samplers_kdiffusion.py --- modules/{sd_samplers.py => sd_samplers_kdiffusion.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => sd_samplers_kdiffusion.py} (100%) diff --git a/modules/sd_samplers.py b/modules/sd_samplers_kdiffusion.py similarity index 100% rename from modules/sd_samplers.py rename to modules/sd_samplers_kdiffusion.py From 2db8ed32cd71fab68169dcb1b49998917190e3c7 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:51:23 +0300 Subject: [PATCH 41/46] Split history sd_samplers.py to sd_samplers_kdiffusion.py --- modules/{sd_samplers.py => temp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{sd_samplers.py => temp} (100%) diff --git a/modules/sd_samplers.py b/modules/temp similarity index 100% rename from modules/sd_samplers.py rename to modules/temp From 274474105a5166a985a47508ffd0695db41623a5 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Mon, 30 Jan 2023 09:51:23 +0300 Subject: [PATCH 42/46] Split history sd_samplers.py to sd_samplers_kdiffusion.py --- modules/{temp => sd_samplers.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{temp => sd_samplers.py} (100%) diff --git a/modules/temp b/modules/sd_samplers.py similarity index 100% rename from modules/temp rename to modules/sd_samplers.py From 4df63d2d197f26181758b5108f003f225fe84874 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 10:11:30 +0300 Subject: [PATCH 43/46] split samplers into one more files for k-diffusion --- modules/sd_samplers.py | 302 +----------------------------- modules/sd_samplers_common.py | 3 +- modules/sd_samplers_compvis.py | 8 + modules/sd_samplers_kdiffusion.py | 57 +----- 4 files changed, 22 insertions(+), 348 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 9a29f1ae..28c2136f 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,49 +1,11 @@ -from collections import deque -import torch -import inspect -import k_diffusion.sampling -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis - -from modules.shared import opts, state -import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback +from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image - -samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), - ('Euler', 'sample_euler', ['k_euler'], {}), - ('LMS', 'sample_lms', ['k_lms'], {}), - ('Heun', 'sample_heun', ['k_heun'], {}), - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), - ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), - ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}), -] - -samplers_data_k_diffusion = [ - sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) - for label, funcname, aliases, options in samplers_k_diffusion - if hasattr(k_diffusion.sampling, funcname) -] - all_samplers = [ - *samplers_data_k_diffusion, - sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), + *sd_samplers_kdiffusion.samplers_data_k_diffusion, + *sd_samplers_compvis.samplers_data_compvis, ] all_samplers_map = {x.name: x for x in all_samplers} @@ -69,8 +31,8 @@ def create_sampler(name, model): def set_samplers(): global samplers, samplers_for_img2img - hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS']) + hidden = set(shared.opts.hide_samplers) + hidden_img2img = set(shared.opts.hide_samplers + ['PLMS']) samplers = [x for x in all_samplers if x.name not in hidden] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] @@ -83,257 +45,3 @@ def set_samplers(): set_samplers() - -sampler_extra_params = { - 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], -} - - -class CFGDenoiser(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent = None - self.step = 0 - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - - if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) - elif opts.live_preview_content == "Negative prompt": - sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - self.step += 1 - - return denoised - - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - if x.device.type == 'mps': - return torch.randn_like(x, device=devices.cpu).to(x.device) - else: - return torch.randn_like(x) - - -class KDiffusionSampler: - def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.default_eta = 1.0 - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - sd_samplers_common.store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise sd_samplers_common.InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except sd_samplers_common.InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.eta = p.eta or opts.eta_ancestral - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - extra_params_kwargs['eta'] = self.eta - - return extra_params_kwargs - - def get_sigmas(self, p, steps): - discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) - if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: - discard_next_to_last_sigma = True - p.extra_generation_params["Discard penultimate sigma"] = True - - steps += 1 if discard_next_to_last_sigma else 0 - - if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) - - if discard_next_to_last_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - - return sigmas - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - - sigmas = self.get_sigmas(p, steps) - - sigma_sched = sigmas[steps - t_enc - 1:] - xi = x + noise * sigma_sched[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last - extra_params_kwargs['sigma_min'] = sigma_sched[-2] - if 'sigma_max' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_max'] = sigma_sched[0] - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = len(sigma_sched) - 1 - if 'sigma_sched' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_sched'] = sigma_sched - if 'sigmas' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigmas'] = sigma_sched - - self.model_wrap_cfg.init_latent = x - self.last_latent = x - - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): - steps = steps or p.steps - - sigmas = self.get_sigmas(p, steps) - - x = x * sigmas[0] - - extra_params_kwargs = self.initialize(p) - if 'sigma_min' in inspect.signature(self.func).parameters: - extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() - extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in inspect.signature(self.func).parameters: - extra_params_kwargs['n'] = steps - else: - extra_params_kwargs['sigmas'] = sigmas - - self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - return samples - diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 5b06e341..3c03d442 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,4 +1,4 @@ -from collections import namedtuple, deque +from collections import namedtuple import numpy as np import torch from PIL import Image @@ -64,6 +64,7 @@ class InterruptedException(BaseException): # MPS fix for randn in torchsde +# XXX move this to separate file for MPS def torchsde_randn(size, dtype, device, seed): if device.type == 'mps': generator = torch.Generator(devices.cpu).manual_seed(int(seed)) diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 3d35ff72..88541193 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -1,4 +1,6 @@ import math +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms import numpy as np import torch @@ -7,6 +9,12 @@ from modules.shared import state from modules import sd_samplers_common, prompt_parser, shared +samplers_data_compvis = [ + sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), + sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), +] + + class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 9a29f1ae..adb6883e 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,18 +2,12 @@ from collections import deque import torch import inspect import k_diffusion.sampling -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis from modules.shared import opts, state import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -# imports for functions that previously were here and are used by other modules -from modules.sd_samplers_common import samples_to_image_grid, sample_to_image - - samplers_k_diffusion = [ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), ('Euler', 'sample_euler', ['k_euler'], {}), @@ -40,50 +34,6 @@ samplers_data_k_diffusion = [ if hasattr(k_diffusion.sampling, funcname) ] -all_samplers = [ - *samplers_data_k_diffusion, - sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), - sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), -] -all_samplers_map = {x.name: x for x in all_samplers} - -samplers = [] -samplers_for_img2img = [] -samplers_map = {} - - -def create_sampler(name, model): - if name is not None: - config = all_samplers_map.get(name, None) - else: - config = all_samplers[0] - - assert config is not None, f'bad sampler name: {name}' - - sampler = config.constructor(model) - sampler.config = config - - return sampler - - -def set_samplers(): - global samplers, samplers_for_img2img - - hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS']) - - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] - - samplers_map.clear() - for sampler in all_samplers: - samplers_map[sampler.name.lower()] = sampler.name - for alias in sampler.aliases: - samplers_map[alias.lower()] = sampler.name - - -set_samplers() - sampler_extra_params = { 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], @@ -92,6 +42,13 @@ sampler_extra_params = { class CFGDenoiser(torch.nn.Module): + """ + Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. + """ + def __init__(self, model): super().__init__() self.inner_model = model From 040ec7a80e23d340efe1108b9de5ead62d9011a9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 10:47:09 +0300 Subject: [PATCH 44/46] make the program read Eta and Eta DDIM from generation parameters --- modules/generation_parameters_copypaste.py | 2 ++ modules/processing.py | 1 - modules/sd_samplers_compvis.py | 3 ++- modules/sd_samplers_kdiffusion.py | 8 +++++--- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 2a10524f..7ee8ee10 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -293,6 +293,8 @@ infotext_to_setting_name_mapping = [ ('Model hash', 'sd_model_checkpoint'), ('ENSD', 'eta_noise_seed_delta'), ('Noise multiplier', 'initial_noise_multiplier'), + ('Eta', 'eta_ancestral'), + ('Eta DDIM', 'eta_ddim'), ] diff --git a/modules/processing.py b/modules/processing.py index 2d295932..e544c2e1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -455,7 +455,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, - "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, } diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 88541193..d03131cd 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -27,7 +27,6 @@ class VanillaStableDiffusionSampler: self.step = 0 self.stop_at = None self.eta = None - self.default_eta = 0.0 self.config = None self.last_latent = None @@ -102,6 +101,8 @@ class VanillaStableDiffusionSampler: def initialize(self, p): self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim + if self.eta != 0.0: + p.extra_generation_params["Eta DDIM"] = self.eta for fieldname in ['p_sample_ddim', 'p_sample_plms']: if hasattr(self.sampler, fieldname): diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index adb6883e..aa7f106b 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,7 +2,7 @@ from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis +from modules import prompt_parser, devices, sd_samplers_common from modules.shared import opts, state import modules.shared as shared @@ -164,7 +164,6 @@ class KDiffusionSampler: self.sampler_noises = None self.stop_at = None self.eta = None - self.default_eta = 1.0 self.config = None self.last_latent = None @@ -199,7 +198,7 @@ class KDiffusionSampler: self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 - self.eta = p.eta or opts.eta_ancestral + self.eta = p.eta if p.eta is not None else opts.eta_ancestral k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) @@ -209,6 +208,9 @@ class KDiffusionSampler: extra_params_kwargs[param_name] = getattr(p, param_name) if 'eta' in inspect.signature(self.func).parameters: + if self.eta != 1.0: + p.extra_generation_params["Eta"] = self.eta + extra_params_kwargs['eta'] = self.eta return extra_params_kwargs From ab059b6e4863eaa5e118a2043192584e6df51ed4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 10:52:15 +0300 Subject: [PATCH 45/46] make the program read Discard penultimate sigma from generation parameters --- modules/generation_parameters_copypaste.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 7ee8ee10..fc9e17aa 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -295,6 +295,7 @@ infotext_to_setting_name_mapping = [ ('Noise multiplier', 'initial_noise_multiplier'), ('Eta', 'eta_ancestral'), ('Eta DDIM', 'eta_ddim'), + ('Discard penultimate sigma', 'always_discard_next_to_last_sigma') ] From aa4688eb8345de583070ca9ddb4c6f585f06762b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 30 Jan 2023 13:29:44 +0300 Subject: [PATCH 46/46] disable EMA weights for instructpix2pix model, whcih should get memory usage as well as image quality to what it was before d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2 --- configs/instruct-pix2pix.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml index 437ddcef..4e896879 100644 --- a/configs/instruct-pix2pix.yaml +++ b/configs/instruct-pix2pix.yaml @@ -20,8 +20,7 @@ model: conditioning_key: hybrid monitor: val/loss_simple_ema scale_factor: 0.18215 - use_ema: true - load_ema: true + use_ema: false scheduler_config: # 10000 warmup steps target: ldm.lr_scheduler.LambdaLinearScheduler