initial SDXL refiner support
commit 6d8dcdefa0
parent dc39061856
@@ -180,21 +180,29 @@ class StableDiffusionModelHijack:
     def hijack(self, m):
         conditioner = getattr(m, 'conditioner', None)
         if conditioner:
+            text_cond_models = []
+
             for i in range(len(conditioner.embedders)):
                 embedder = conditioner.embedders[i]
                 typename = type(embedder).__name__
                 if typename == 'FrozenOpenCLIPEmbedder':
                     embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
-                    m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
-                    conditioner.embedders[i] = m.cond_stage_model
+                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
                 if typename == 'FrozenCLIPEmbedder':
-                    model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+                    model_embeddings = embedder.transformer.text_model.embeddings
                     model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
-                    m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
-                    conditioner.embedders[i] = m.cond_stage_model
+                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
                 if typename == 'FrozenOpenCLIPEmbedder2':
                     embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                     conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
+
+            if len(text_cond_models) == 1:
+                m.cond_stage_model = text_cond_models[0]
+            else:
+                m.cond_stage_model = conditioner
 
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
             model_embeddings = m.cond_stage_model.roberta.embeddings
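
Note: with this hunk every text embedder inside SDXL's GeneralConditioner gets wrapped, and m.cond_stage_model is then chosen from the collected wrappers. A minimal sketch of that selection (illustrative only, not webui code; the refiner carries a single OpenCLIP text encoder while the base model carries two):

    # hypothetical helper mirroring the logic added above
    def pick_cond_stage_model(conditioner, text_cond_models):
        if len(text_cond_models) == 1:
            # SDXL refiner: the lone wrapped encoder can act as cond_stage_model
            # directly, the same way SD1/SD2 expose a single text encoder
            return text_cond_models[0]
        # SDXL base: two encoders (CLIP ViT-L + OpenCLIP ViT-bigG), so the whole
        # conditioner stands in for cond_stage_model
        return conditioner
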
@@ -414,6 +414,7 @@ def repair_config(sd_config):
 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
 sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
+sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
 
 
 class SdModelData:
@@ -477,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
-    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict
+    clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict])
 
     timer.record("find config")
 
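
Note: the rewritten clip_is_included_into_sd check is the old boolean chain extended with the new refiner key; an equivalent, slightly tidier spelling (sketch only) would be:

    clip_weights = [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight]
    clip_is_included_into_sd = any(x in state_dict for x in clip_weights)
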
@@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
 config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -72,6 +73,8 @@ def guess_model_config_from_state_dict(sd, filename):
 
     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
         return config_sdxl
+    if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
+        return config_sdxl_refiner
     elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
         return config_depth_model
     elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
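
Note: config detection keys off which text encoders the checkpoint carries. The SDXL base model has a second embedder, so 'conditioner.embedders.1.model.ln_final.weight' exists; the refiner only ships its OpenCLIP encoder at index 0, which the new branch maps to the refiner config. A rough, self-contained sketch with made-up minimal state dicts:

    def guess(sd):
        if sd.get('conditioner.embedders.1.model.ln_final.weight') is not None:
            return 'sd_xl_base.yaml'
        if sd.get('conditioner.embedders.0.model.ln_final.weight') is not None:
            return 'sd_xl_refiner.yaml'
        return 'not SDXL'

    base_sd = {'conditioner.embedders.1.model.ln_final.weight': 0}     # second encoder present -> base
    refiner_sd = {'conditioner.embedders.0.model.ln_final.weight': 0}  # only encoder 0 -> refiner
    assert guess(base_sd) == 'sd_xl_base.yaml' and guess(refiner_sd) == 'sd_xl_refiner.yaml'
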
@@ -14,15 +14,20 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
 
     width = getattr(self, 'target_width', 1024)
     height = getattr(self, 'target_height', 1024)
+    is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
+    aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
+
+    devices_args = dict(device=devices.device, dtype=devices.dtype)
 
     sdxl_conds = {
         "txt": batch,
-        "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype),
-        "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype),
-        "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype),
+        "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+        "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
+        "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+        "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
     }
 
-    force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch)
+    force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
     c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
 
     return c
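
Note: each entry in sdxl_conds is repeated along the batch dimension, and the refiner additionally receives an aesthetic score, low for negative prompts and high for positive ones. A small shape sketch in plain torch (illustrative, not webui code):

    import torch

    batch = ["a photo of a cat", "a photo of a dog"]
    height, width = 1024, 1024
    aesthetic_score = 6.0  # 2.5 would be used for a negative-prompt batch

    original_size = torch.tensor([height, width]).repeat(len(batch), 1)  # shape (2, 2)
    score = torch.tensor([aesthetic_score]).repeat(len(batch), 1)        # shape (2, 1)
    print(original_size.shape, score.shape)  # torch.Size([2, 2]) torch.Size([2, 1])
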
@@ -35,25 +40,55 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
 def get_first_stage_encoding(self, x):  # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
     return x
 
 
+sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
+sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
+sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
+
+
+def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
+    res = []
+
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
+        encoded = embedder.encode_embedding_init_text(init_text, nvpt)
+        res.append(encoded)
+
+    return torch.cat(res, dim=1)
+
+
+def process_texts(self, texts):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
+        return embedder.process_texts(texts)
+
+
+def get_target_prompt_token_count(self, token_count):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
+        return embedder.get_target_prompt_token_count(token_count)
+
+
+# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
+sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
+sgm.modules.GeneralConditioner.process_texts = process_texts
+sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
+
+
 def extend_sdxl(model):
+    """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
+
     dtype = next(model.model.diffusion_model.parameters()).dtype
     model.model.diffusion_model.dtype = dtype
     model.model.conditioning_key = 'crossattn'
+    model.cond_stage_key = 'txt'
+    # model.cond_stage_model will be set in sd_hijack
 
-    model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0]
-    model.cond_stage_key = model.cond_stage_model.input_key
-
     model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
 
     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
     model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
 
+    model.conditioner.wrapped = torch.nn.Module()
+
 
-sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
-sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
-sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
-
 sgm.modules.attention.print = lambda *args: None
 sgm.modules.diffusionmodules.model.print = lambda *args: None
 sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
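
Note: the new module-level functions are patched onto sgm.modules.GeneralConditioner so the rest of the codebase can keep treating it like an SD1.5 cond_stage_model; each one simply delegates to whichever wrapped embedder exposes the method. For encode_embedding_init_text the per-encoder results are concatenated along dim=1, so (assuming the usual hidden sizes of 768 for CLIP ViT-L and 1280 for OpenCLIP ViT-bigG) an SDXL-base textual-inversion vector would end up 2048 wide, while the refiner, with only one encoder, stays at 1280:

    import torch

    # hypothetical per-encoder init embeddings for one textual-inversion vector
    clip_l = torch.zeros(1, 768)        # CLIP ViT-L hidden size (base model only)
    open_clip_g = torch.zeros(1, 1280)  # OpenCLIP ViT-bigG hidden size

    combined = torch.cat([clip_l, open_clip_g], dim=1)
    print(combined.shape)  # torch.Size([1, 2048])
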
@@ -428,8 +428,13 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
-    "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"),
-    "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"),
+}))
+
+options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
+    "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
+    "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
+    "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
+    "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
 }))
 
 options_templates.update(options_section(('optimizations', "Optimizations"), {
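
Note: the crop options move out of the general "Stable Diffusion" section into a new "Stable Diffusion XL" section, and the two aesthetic-score options feed the refiner conditioning added in get_learned_conditioning above; the 2.5 / 6.0 defaults appear to match the values used in Stability's reference sgm sampling scripts. They are read as ordinary opts, e.g.:

    score = (shared.opts.sdxl_refiner_low_aesthetic_score
             if is_negative_prompt
             else shared.opts.sdxl_refiner_high_aesthetic_score)
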