fix caching for img2imgalt
This commit is contained in:
parent
91c56c51c7
commit
d51847c184
@ -1,3 +1,5 @@
|
|||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
@ -56,9 +58,14 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
|||||||
|
|
||||||
return x / x.std()
|
return x / x.std()
|
||||||
|
|
||||||
cache = [None, None, None, None, None]
|
|
||||||
|
Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt"])
|
||||||
|
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
|
def __init__(self):
|
||||||
|
self.cache = None
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
return "img2img alternative test"
|
return "img2img alternative test"
|
||||||
|
|
||||||
@ -67,7 +74,7 @@ class Script(scripts.Script):
|
|||||||
|
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
original_prompt = gr.Textbox(label="Original prompt", lines=1)
|
original_prompt = gr.Textbox(label="Original prompt", lines=1)
|
||||||
cfg = gr.Slider(label="Decode CFG scale", minimum=0.1, maximum=3.0, step=0.1, value=1.0)
|
cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
|
||||||
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50)
|
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50)
|
||||||
|
|
||||||
return [original_prompt, cfg, st]
|
return [original_prompt, cfg, st]
|
||||||
@ -77,19 +84,18 @@ class Script(scripts.Script):
|
|||||||
p.batch_count = 1
|
p.batch_count = 1
|
||||||
|
|
||||||
def sample_extra(x, conditioning, unconditional_conditioning):
|
def sample_extra(x, conditioning, unconditional_conditioning):
|
||||||
lat = tuple([int(x*10) for x in p.init_latent.cpu().numpy().flatten().tolist()])
|
lat = (p.init_latent.cpu().numpy() * 10).astype(int)
|
||||||
|
|
||||||
if cache[0] is not None and cache[1] == cfg and cache[2] == st and len(cache[3]) == len(lat) and sum(np.array(cache[3])-np.array(lat)) < 100 and cache[4] == original_prompt:
|
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st and self.cache.original_prompt == original_prompt
|
||||||
noise = cache[0]
|
same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100
|
||||||
|
|
||||||
|
if same_everything:
|
||||||
|
noise = self.cache.noise
|
||||||
else:
|
else:
|
||||||
shared.state.job_count += 1
|
shared.state.job_count += 1
|
||||||
cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
|
cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
|
||||||
noise = find_noise_for_image(p, cond, unconditional_conditioning, cfg, st)
|
noise = find_noise_for_image(p, cond, unconditional_conditioning, cfg, st)
|
||||||
cache[0] = noise
|
self.cache = Cached(noise, cfg, st, lat, original_prompt)
|
||||||
cache[1] = cfg
|
|
||||||
cache[2] = st
|
|
||||||
cache[3] = lat
|
|
||||||
cache[4] = original_prompt
|
|
||||||
|
|
||||||
sampler = samplers[p.sampler_index].constructor(p.sd_model)
|
sampler = samplers[p.sampler_index].constructor(p.sd_model)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user