import contextlib
import json
import math
import os
import sys

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
import random

import modules.sd_hijack
from modules import devices
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles

# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8


class StableDiffusionProcessing:
    """Holds the parameters of one generation job.

    Subclasses implement ``init`` (called once by ``process_images``) and
    ``sample`` (called once per batch by ``process_images``).
    """

    def __init__(
        self,
        sd_model=None,
        outpath_samples=None,
        outpath_grids=None,
        prompt="",
        prompt_style="None",
        seed=-1,
        subseed=-1,
        subseed_strength=0,
        seed_resize_from_h=-1,
        seed_resize_from_w=-1,
        sampler_index=0,
        batch_size=1,
        n_iter=1,
        steps=50,
        cfg_scale=7.0,
        width=512,
        height=512,
        restore_faces=False,
        tiling=False,
        do_not_save_samples=False,
        do_not_save_grid=False,
        extra_generation_params=None,
        overlay_images=None,
        negative_prompt=None,
    ):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        # a missing negative prompt is normalised to the empty string
        self.negative_prompt: str = (negative_prompt or "")
        self.prompt_style: str = prompt_style
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params
        self.overlay_images = overlay_images
        # (x, y, w, h) rectangle the result is pasted into; set by img2img
        # when inpainting at full resolution, otherwise None
        self.paste_to = None

    def init(self, seed):
        """Prepare for sampling; the base implementation does nothing."""
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        """Produce samples for one batch; must be overridden by subclasses."""
        raise NotImplementedError()


class Processed:
    """Result of ``process_images``: the images plus the settings that produced them."""

    def __init__(self, p: StableDiffusionProcessing, images_list, seed, info):
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler = samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps

    def js(self):
        """Return the generation settings as a JSON string.

        When prompt, negative prompt or seed are lists, only the first
        element is reported.
        """

        def first(value):
            return value[0] if isinstance(value, list) else value

        obj = {
            "prompt": first(self.prompt),
            "negative_prompt": first(self.negative_prompt),
            "seed": int(first(self.seed)),
            "width": self.width,
            "height": self.height,
            "sampler": self.sampler,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
        }

        return json.dumps(obj)


# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    """Spherically interpolate between ``low`` and ``high``.

    ``val`` is the interpolation weight: 0 yields ``low``, 1 yields ``high``.
    Norms and the angle between the inputs are computed along dimension 1.

    Where the two inputs are parallel or anti-parallel along that dimension
    the spherical formula divides by sin(omega) == 0; for those entries the
    weights fall back to plain linear interpolation instead of producing NaN.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)

    # rounding can push the dot product slightly outside [-1, 1], which would make acos return NaN
    dot = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(dot)
    so = torch.sin(omega)

    degenerate = so.abs() < 1e-6
    safe_so = torch.where(degenerate, torch.ones_like(so), so)

    low_weight = torch.where(degenerate, torch.zeros_like(so) + (1.0 - val), torch.sin((1.0 - val) * omega) / safe_so)
    high_weight = torch.where(degenerate, torch.zeros_like(so) + val, torch.sin(val * omega) / safe_so)

    return low_weight.unsqueeze(1) * low + high_weight.unsqueeze(1) * high


def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
    """Create one seeded noise tensor per entry of ``seeds`` and stack them.

    shape: shape of a single noise tensor (channels, height, width).
    seeds: one seed per tensor.
    subseeds: optional variation seeds; a missing entry is treated as 0.
    subseed_strength: slerp weight between the seed noise and the subseed noise.
    seed_resize_from_h / seed_resize_from_w: when both are positive, noise is
        generated for that image size (divided by ``opt_f``) and then pasted,
        centered and cropped as needed, into fresh noise of ``shape``.

    Returns the stacked tensor moved to ``shared.device``.
    """
    resize_requested = seed_resize_from_h > 0 and seed_resize_from_w > 0

    xs = []
    for i, seed in enumerate(seeds):
        if resize_requested:
            noise_shape = (shape[0], seed_resize_from_h // opt_f, seed_resize_from_w // opt_f)
        else:
            noise_shape = shape

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]
            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            #noise = subnoise * subseed_strength + noise * (1 - subseed_strength)
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            #noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
            x = devices.randn(seed, shape)

            # offsets are positive when the target is larger than the source noise
            # and negative when the source noise has to be cropped
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        xs.append(noise)

    x = torch.stack(xs).to(shared.device)
    return x


def fix_seed(p):
    """Replace a missing (None or -1) seed / subseed on ``p`` with a random one."""
    p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == -1 else p.seed
    p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed


def process_images(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls p.init once inside all the scopes and p.sample once per batch"""

    assert p.prompt is not None
    devices.torch_gc()

    fix_seed(p)

    os.makedirs(p.outpath_samples, exist_ok=True)
    os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)

    comments = []

    modules.styles.apply_style(p, shared.prompt_styles[p.prompt_style])

    if isinstance(p.prompt, list):
        all_prompts = p.prompt
    else:
        all_prompts = p.batch_size * p.n_iter * [p.prompt]

    if isinstance(p.seed, list):
        all_seeds = p.seed
    else:
        # with a variation seed in use every image keeps the same base seed
        all_seeds = [int(p.seed + (offset if p.subseed_strength == 0 else 0)) for offset in range(len(all_prompts))]

    if isinstance(p.subseed, list):
        all_subseeds = p.subseed
    else:
        all_subseeds = [int(p.subseed + offset) for offset in range(len(all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        """Build the text describing how the image at the given position was generated."""
        index = position_in_batch + iteration * p.batch_size

        # seed resizing is disabled for any non-positive size (the defaults are -1),
        # matching the check in create_random_tensors
        seed_resize_disabled = p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0

        generation_params = {
            "Steps": p.steps,
            "Sampler": samplers[p.sampler_index].name,
            "CFG scale": p.cfg_scale,
            "Seed": all_seeds[index],
            "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
            "Size": f"{p.width}x{p.height}",
            "Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model_hash else shared.sd_model_hash),
            "Batch size": (None if p.batch_size < 2 else p.batch_size),
            "Batch pos": (None if p.batch_size < 2 else position_in_batch),
            "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
            "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
            "Seed resize from": (None if seed_resize_disabled else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
            "Denoising strength": getattr(p, 'denoising_strength', None),
        }

        if p.extra_generation_params is not None:
            generation_params.update(p.extra_generation_params)

        # entries whose value is None are omitted; a key equal to its value is printed once
        generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

        negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

        return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)

    output_images = []
    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.init(seed=all_seeds[0])

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            if state.interrupted:
                break

            batch_slice = slice(n * p.batch_size, (n + 1) * p.batch_size)
            prompts = all_prompts[batch_slice]
            seeds = all_seeds[batch_slice]
            subseeds = all_subseeds[batch_slice]

            uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
            c = p.sd_model.get_learned_conditioning(prompts)

            if len(model_hijack.comments) > 0:
                comments += model_hijack.comments

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors(
                [opt_C, p.height // opt_f, p.width // opt_f],
                seeds=seeds,
                subseeds=subseeds,
                subseed_strength=p.subseed_strength,
                seed_resize_from_h=p.seed_resize_from_h,
                seed_resize_from_w=p.seed_resize_from_w,
            )

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
            if state.interrupted:
                # if we are interrupted, sample returns just noise
                # use the image collected previously in sampler loop
                samples_ddim = shared.state.current_latent

            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            if opts.filter_nsfw:
                # imported under an alias on purpose: a plain `import modules.safety`
                # here would make `modules` a local name for the whole function
                import modules.safety as safety
                x_samples_ddim = safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if opts.samples_save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)

                image = Image.fromarray(x_sample)

                if p.overlay_images is not None and i < len(p.overlay_images):
                    overlay = p.overlay_images[i]

                    if p.paste_to is not None:
                        # place the generated region back at its position inside a canvas of the overlay's size
                        paste_x, paste_y, paste_w, paste_h = p.paste_to
                        base_image = Image.new('RGBA', (overlay.width, overlay.height))
                        image = images.resize_image(1, image, paste_w, paste_h)
                        base_image.paste(image, (paste_x, paste_y))
                        image = base_image

                    image = image.convert('RGBA')
                    image.alpha_composite(overlay)
                    image = image.convert('RGB')

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                output_images.append(image)

            state.nextjob()

        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            return_grid = opts.return_grid

            grid = images.image_grid(output_images, p.batch_size)

            if return_grid:
                output_images.insert(0, grid)

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p)

    devices.torch_gc()
    return Processed(p, output_images, all_seeds[0], infotext())


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """Text-to-image job: samples directly from the provided noise."""

    sampler = None

    def init(self, seed):
        """Construct the sampler selected by ``sampler_index``."""
        self.sampler = samplers[self.sampler_index].constructor(self.sd_model)

    def sample(self, x, conditioning, unconditional_conditioning):
        """Run the sampler on noise ``x`` and return its samples."""
        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
        return samples_ddim


def get_crop_region(mask, pad=0):
    """Find the bounding box of the non-zero area of a 2D mask array.

    Returns (x1, y1, x2, y2), expanded by ``pad`` on every side and clamped
    to the mask's bounds.
    """
    h, w = mask.shape

    # count fully-zero columns from the left
    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1

    # count fully-zero columns from the right
    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1

    # count fully-zero rows from the top
    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    # count fully-zero rows from the bottom
    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left-pad, 0)),
        int(max(crop_top-pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )


def fill(image, mask):
    """Fill the masked area of ``image`` with colors blurred in from the unmasked area.

    The image with the masked area cut out is blurred at progressively
    smaller radii and the results are composited on top of each other.
    Returns a new RGB image.
    """
    image_mod = Image.new('RGBA', (image.width, image.height))

    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))

    image_masked = image_masked.convert('RGBa')

    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """Image-to-image job, optionally restricted to a masked area (inpainting)."""

    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpainting_mask_invert=0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpainting_mask_invert = inpainting_mask_invert
        # latent-space masks, populated by init() when an image mask is given
        self.mask = None
        self.nmask = None

    def init(self, seed):
        """Construct the sampler, prepare the mask and overlays, and encode the init images.

        Raises RuntimeError when more init images than ``batch_size`` are given.
        """
        self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                # work only on the masked region: crop to it, process at the target size
                # and remember where the result has to be pasted back
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = get_crop_region(np.array(mask), opts.upscale_at_full_resolution_padding)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                np_mask = np.array(self.image_mask)
                # np.float64 is what the removed np.float alias referred to
                np_mask = np.clip((np_mask.astype(np.float64)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                # keep the unmasked part of the original to composite over the result later
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                if self.inpainting_fill != 1:
                    image = fill(image, latent_mask)

            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            # a single init image is repeated to fill the batch
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (opt_C, 1, 1))

            # mask is 1 where the original latent is kept, nmask is 1 where new content goes
            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            if self.inpainting_fill == 2:
                # replace the masked area of the latent with seeded noise
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
            elif self.inpainting_fill == 3:
                # zero out the masked area of the latent
                self.init_latent = self.init_latent * self.mask

    def sample(self, x, conditioning, unconditional_conditioning):
        """Run img2img sampling; with a mask, keep the original latent outside the masked area."""
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        return samples