extra samplers from K-diffusion
This commit is contained in:
parent
91dc8710ec
commit
c9579b51a6
55
webui.py
55
webui.py
@ -1,4 +1,6 @@
|
|||||||
import argparse, os, sys, glob
|
import argparse, os, sys, glob
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -16,7 +18,7 @@ import time
|
|||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import k_diffusion as K
|
import k_diffusion.sampling
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
@ -60,6 +62,19 @@ css_hide_progressbar = """
|
|||||||
.meta-text { display:none!important; }
|
.meta-text { display:none!important; }
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
|
||||||
|
samplers = [
|
||||||
|
*[SamplerData(x[0], lambda model: KDiffusionSampler(model, x[1])) for x in [
|
||||||
|
('LMS', 'sample_lms'),
|
||||||
|
('Heun', 'sample_heun'),
|
||||||
|
('Euler', 'sample_euler'),
|
||||||
|
('Euler ancestral', 'sample_euler_ancestral'),
|
||||||
|
('DPM 2', 'sample_dpm_2'),
|
||||||
|
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
|
||||||
|
] if hasattr(k_diffusion.sampling, x[1])],
|
||||||
|
SamplerData('DDIM', lambda model: DDIMSampler(model)),
|
||||||
|
SamplerData('PLMS', lambda model: PLMSSampler(model)),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Options:
|
class Options:
|
||||||
@ -142,16 +157,18 @@ class CFGDenoiser(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler:
|
class KDiffusionSampler:
|
||||||
def __init__(self, m):
|
def __init__(self, m, funcname):
|
||||||
self.model = m
|
self.model = m
|
||||||
self.model_wrap = K.external.CompVisDenoiser(m)
|
self.model_wrap = k_diffusion.external.CompVisDenoiser(m)
|
||||||
|
self.funcname = funcname
|
||||||
|
|
||||||
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
|
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
|
||||||
sigmas = self.model_wrap.get_sigmas(S)
|
sigmas = self.model_wrap.get_sigmas(S)
|
||||||
x = x_T * sigmas[0]
|
x = x_T * sigmas[0]
|
||||||
model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||||
|
|
||||||
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
|
fun = getattr(k_diffusion.sampling, self.funcname)
|
||||||
|
samples_ddim = fun(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
|
||||||
|
|
||||||
return samples_ddim, None
|
return samples_ddim, None
|
||||||
|
|
||||||
@ -526,7 +543,7 @@ def get_learned_conditioning_with_embeddings(model, prompts):
|
|||||||
return model.get_learned_conditioning(prompts)
|
return model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
|
|
||||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
|
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
|
||||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||||
|
|
||||||
assert prompt is not None
|
assert prompt is not None
|
||||||
@ -579,7 +596,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
|||||||
def infotext():
|
def infotext():
|
||||||
return f"""
|
return f"""
|
||||||
{prompt}
|
{prompt}
|
||||||
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
Steps: {steps}, Sampler: {samplers[sampler_index].name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
||||||
""".strip() + "".join(["\n\n" + x for x in comments])
|
""".strip() + "".join(["\n\n" + x for x in comments])
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir):
|
if os.path.exists(cmd_opts.embeddings_dir):
|
||||||
@ -645,17 +662,10 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{',
|
|||||||
return output_images, seed, infotext()
|
return output_images, seed, infotext()
|
||||||
|
|
||||||
|
|
||||||
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
|
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||||
outpath = opts.outdir or "outputs/txt2img-samples"
|
outpath = opts.outdir or "outputs/txt2img-samples"
|
||||||
|
|
||||||
if sampler_name == 'PLMS':
|
sampler = samplers[sampler_index].constructor(model)
|
||||||
sampler = PLMSSampler(model)
|
|
||||||
elif sampler_name == 'DDIM':
|
|
||||||
sampler = DDIMSampler(model)
|
|
||||||
elif sampler_name == 'k-diffusion':
|
|
||||||
sampler = KDiffusionSampler(model)
|
|
||||||
else:
|
|
||||||
raise Exception("Unknown sampler: " + sampler_name)
|
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
pass
|
pass
|
||||||
@ -670,7 +680,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, p
|
|||||||
func_sample=sample,
|
func_sample=sample,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
sampler_name=sampler_name,
|
sampler_index=sampler_index,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=ddim_steps,
|
steps=ddim_steps,
|
||||||
@ -732,7 +742,7 @@ txt2img_interface = gr.Interface(
|
|||||||
inputs=[
|
inputs=[
|
||||||
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
|
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
|
||||||
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
||||||
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
|
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
|
||||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||||
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
||||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
||||||
@ -756,7 +766,7 @@ txt2img_interface = gr.Interface(
|
|||||||
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
||||||
outpath = opts.outdir or "outputs/img2img-samples"
|
outpath = opts.outdir or "outputs/img2img-samples"
|
||||||
|
|
||||||
sampler = KDiffusionSampler(model)
|
sampler = KDiffusionSampler(model, 'sample_lms')
|
||||||
|
|
||||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
|
|
||||||
@ -785,7 +795,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
|
|||||||
xi = x0 + noise
|
xi = x0 + noise
|
||||||
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
|
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
|
||||||
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
|
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
|
||||||
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
|
samples_ddim = k_diffusion.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
if loopback:
|
if loopback:
|
||||||
@ -800,7 +810,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
|
|||||||
func_sample=sample,
|
func_sample=sample,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
sampler_name='k-diffusion',
|
sampler_index=0,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
n_iter=1,
|
n_iter=1,
|
||||||
steps=ddim_steps,
|
steps=ddim_steps,
|
||||||
@ -835,7 +845,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
|
|||||||
func_sample=sample,
|
func_sample=sample,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
sampler_name='k-diffusion',
|
sampler_index=0,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=ddim_steps,
|
steps=ddim_steps,
|
||||||
@ -877,10 +887,10 @@ img2img_interface = gr.Interface(
|
|||||||
gr.Number(label='Seed'),
|
gr.Number(label='Seed'),
|
||||||
gr.HTML(),
|
gr.HTML(),
|
||||||
],
|
],
|
||||||
title="Stable Diffusion Image-to-Image",
|
|
||||||
allow_flagging="never",
|
allow_flagging="never",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_GFPGAN(image, strength):
|
def run_GFPGAN(image, strength):
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
|
||||||
@ -904,7 +914,6 @@ gfpgan_interface = gr.Interface(
|
|||||||
gr.Number(label='Seed', visible=False),
|
gr.Number(label='Seed', visible=False),
|
||||||
gr.HTML(),
|
gr.HTML(),
|
||||||
],
|
],
|
||||||
title="GFPGAN",
|
|
||||||
description="Fix faces on images",
|
description="Fix faces on images",
|
||||||
allow_flagging="never",
|
allow_flagging="never",
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user