fix conds caching with extra network

This commit is contained in:
w-e-w 2023-06-04 04:24:44 +09:00
parent 30bbb8bce3
commit f098e726d3
2 changed files with 15 additions and 12 deletions

View File

@ -32,6 +32,9 @@ class ExtraNetworkParams:
else: else:
self.positional.append(item) self.positional.append(item)
def __eq__(self, other):
return self.items == other.items and self.positional == other.positional and self.named == other.named
class ExtraNetwork: class ExtraNetwork:
def __init__(self, name): def __init__(self, name):

View File

@ -171,6 +171,7 @@ class StableDiffusionProcessing:
self.prompts = None self.prompts = None
self.negative_prompts = None self.negative_prompts = None
self.extra_network_data = None
self.seeds = None self.seeds = None
self.subseeds = None self.subseeds = None
@ -311,7 +312,7 @@ class StableDiffusionProcessing:
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
def get_conds_with_caching(self, function, required_prompts, steps, cache): def get_conds_with_caching(self, function, required_prompts, steps, cache, extra_network_data):
""" """
Returns the result of calling function(shared.sd_model, required_prompts, steps) Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before. using a cache to store the result if the same arguments have been used before.
@ -321,21 +322,21 @@ class StableDiffusionProcessing:
have been used before. The second element is where the previously have been used before. The second element is where the previously
computed result is stored. computed result is stored.
""" """
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]: if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
return cache[1] return cache[1]
with devices.autocast(): with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps) cache[1] = function(shared.sd_model, required_prompts, steps)
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
return cache[1] return cache[1]
def setup_conds(self): def setup_conds(self):
sampler_config = sd_samplers.find_sampler_config(self.sampler_name) sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc) self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data)
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts) self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts)
@ -681,7 +682,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1: if state.job_count == -1:
state.job_count = p.n_iter state.job_count = p.n_iter
extra_network_data = None
for n in range(p.n_iter): for n in range(p.n_iter):
p.iteration = n p.iteration = n
@ -702,11 +702,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(p.prompts) == 0: if len(p.prompts) == 0:
break break
extra_network_data = p.parse_extra_network_prompts() p.extra_network_data = p.parse_extra_network_prompts()
if not p.disable_extra_networks: if not p.disable_extra_networks:
with devices.autocast(): with devices.autocast():
extra_networks.activate(p, extra_network_data) extra_networks.activate(p, p.extra_network_data)
if p.scripts is not None: if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@ -828,8 +828,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
if not p.disable_extra_networks and extra_network_data: if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, extra_network_data) extra_networks.deactivate(p, p.extra_network_data)
devices.torch_gc() devices.torch_gc()
@ -1101,8 +1101,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
super().setup_conds() super().setup_conds()
if self.enable_hr: if self.enable_hr:
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc) self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c) self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data)
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts() res = super().parse_extra_network_prompts()