Merge pull request #8772 from mcmonkey4eva/img2img-alt-sd2-fix
Fix img2img-alternative-test script for SD v2.x
This commit is contained in:
commit
983d48a921
@ -6,23 +6,21 @@ from tqdm import trange
|
|||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
|
from modules import processing, shared, sd_samplers, sd_samplers_common
|
||||||
from modules.processing import Processed
|
|
||||||
from modules.shared import opts, cmd_opts, state
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import k_diffusion as K
|
import k_diffusion as K
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from torch import autocast
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
|
|
||||||
def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
||||||
x = p.init_latent
|
x = p.init_latent
|
||||||
|
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
dnw = K.external.CompVisDenoiser(shared.sd_model)
|
if shared.sd_model.parameterization == "v":
|
||||||
|
dnw = K.external.CompVisVDenoiser(shared.sd_model)
|
||||||
|
skip = 1
|
||||||
|
else:
|
||||||
|
dnw = K.external.CompVisDenoiser(shared.sd_model)
|
||||||
|
skip = 0
|
||||||
sigmas = dnw.get_sigmas(steps).flip(0)
|
sigmas = dnw.get_sigmas(steps).flip(0)
|
||||||
|
|
||||||
shared.state.sampling_steps = steps
|
shared.state.sampling_steps = steps
|
||||||
@ -37,7 +35,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
|||||||
image_conditioning = torch.cat([p.image_conditioning] * 2)
|
image_conditioning = torch.cat([p.image_conditioning] * 2)
|
||||||
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
|
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
|
||||||
|
|
||||||
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
|
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
|
||||||
t = dnw.sigma_to_t(sigma_in)
|
t = dnw.sigma_to_t(sigma_in)
|
||||||
|
|
||||||
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
|
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
|
||||||
@ -69,7 +67,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
|||||||
x = p.init_latent
|
x = p.init_latent
|
||||||
|
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
dnw = K.external.CompVisDenoiser(shared.sd_model)
|
if shared.sd_model.parameterization == "v":
|
||||||
|
dnw = K.external.CompVisVDenoiser(shared.sd_model)
|
||||||
|
skip = 1
|
||||||
|
else:
|
||||||
|
dnw = K.external.CompVisDenoiser(shared.sd_model)
|
||||||
|
skip = 0
|
||||||
sigmas = dnw.get_sigmas(steps).flip(0)
|
sigmas = dnw.get_sigmas(steps).flip(0)
|
||||||
|
|
||||||
shared.state.sampling_steps = steps
|
shared.state.sampling_steps = steps
|
||||||
@ -84,7 +87,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
|||||||
image_conditioning = torch.cat([p.image_conditioning] * 2)
|
image_conditioning = torch.cat([p.image_conditioning] * 2)
|
||||||
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
|
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
|
||||||
|
|
||||||
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
|
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
|
||||||
|
|
||||||
if i == 1:
|
if i == 1:
|
||||||
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
|
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
|
||||||
@ -125,7 +128,7 @@ class Script(scripts.Script):
|
|||||||
def show(self, is_img2img):
|
def show(self, is_img2img):
|
||||||
return is_img2img
|
return is_img2img
|
||||||
|
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
info = gr.Markdown('''
|
info = gr.Markdown('''
|
||||||
* `CFG Scale` should be 2 or lower.
|
* `CFG Scale` should be 2 or lower.
|
||||||
''')
|
''')
|
||||||
@ -213,4 +216,3 @@ class Script(scripts.Script):
|
|||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user