fix conds caching with extra network
This commit is contained in:
parent
30bbb8bce3
commit
f098e726d3
@ -32,6 +32,9 @@ class ExtraNetworkParams:
|
|||||||
else:
|
else:
|
||||||
self.positional.append(item)
|
self.positional.append(item)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.items == other.items and self.positional == other.positional and self.named == other.named
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetwork:
|
class ExtraNetwork:
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
|
@ -171,6 +171,7 @@ class StableDiffusionProcessing:
|
|||||||
|
|
||||||
self.prompts = None
|
self.prompts = None
|
||||||
self.negative_prompts = None
|
self.negative_prompts = None
|
||||||
|
self.extra_network_data = None
|
||||||
self.seeds = None
|
self.seeds = None
|
||||||
self.subseeds = None
|
self.subseeds = None
|
||||||
|
|
||||||
@ -311,7 +312,7 @@ class StableDiffusionProcessing:
|
|||||||
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
|
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
|
||||||
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
|
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
|
||||||
|
|
||||||
def get_conds_with_caching(self, function, required_prompts, steps, cache):
|
def get_conds_with_caching(self, function, required_prompts, steps, cache, extra_network_data):
|
||||||
"""
|
"""
|
||||||
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
||||||
using a cache to store the result if the same arguments have been used before.
|
using a cache to store the result if the same arguments have been used before.
|
||||||
@ -321,21 +322,21 @@ class StableDiffusionProcessing:
|
|||||||
have been used before. The second element is where the previously
|
have been used before. The second element is where the previously
|
||||||
computed result is stored.
|
computed result is stored.
|
||||||
"""
|
"""
|
||||||
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]:
|
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cache[1] = function(shared.sd_model, required_prompts, steps)
|
cache[1] = function(shared.sd_model, required_prompts, steps)
|
||||||
|
|
||||||
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info)
|
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
def setup_conds(self):
|
def setup_conds(self):
|
||||||
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
||||||
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
||||||
|
|
||||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
|
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.extra_network_data)
|
||||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c)
|
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data)
|
||||||
|
|
||||||
def parse_extra_network_prompts(self):
|
def parse_extra_network_prompts(self):
|
||||||
self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts)
|
self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts)
|
||||||
@ -681,7 +682,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
state.job_count = p.n_iter
|
state.job_count = p.n_iter
|
||||||
|
|
||||||
extra_network_data = None
|
|
||||||
for n in range(p.n_iter):
|
for n in range(p.n_iter):
|
||||||
p.iteration = n
|
p.iteration = n
|
||||||
|
|
||||||
@ -702,11 +702,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if len(p.prompts) == 0:
|
if len(p.prompts) == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
extra_network_data = p.parse_extra_network_prompts()
|
p.extra_network_data = p.parse_extra_network_prompts()
|
||||||
|
|
||||||
if not p.disable_extra_networks:
|
if not p.disable_extra_networks:
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
extra_networks.activate(p, extra_network_data)
|
extra_networks.activate(p, p.extra_network_data)
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||||
@ -828,8 +828,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||||
|
|
||||||
if not p.disable_extra_networks and extra_network_data:
|
if not p.disable_extra_networks and p.extra_network_data:
|
||||||
extra_networks.deactivate(p, extra_network_data)
|
extra_networks.deactivate(p, p.extra_network_data)
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
@ -1101,8 +1101,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
super().setup_conds()
|
super().setup_conds()
|
||||||
|
|
||||||
if self.enable_hr:
|
if self.enable_hr:
|
||||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
|
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data)
|
||||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c)
|
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data)
|
||||||
|
|
||||||
def parse_extra_network_prompts(self):
|
def parse_extra_network_prompts(self):
|
||||||
res = super().parse_extra_network_prompts()
|
res = super().parse_extra_network_prompts()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user