WebUI/modules/sd_models_xl.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

41 lines
1.4 KiB
Python
Raw Normal View History

2023-07-11 18:16:43 +00:00
from __future__ import annotations
import torch
import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]):
    """Return the conditioner output for a batch of prompt strings.

    Before encoding, every embedder of ``self.conditioner`` has its
    ``ucg_rate`` set to 0.0.  The change is left in place after the call.
    (Presumably this disables sgm's training-time random conditioning
    dropout; confirm against sgm's GeneralConditioner.)

    Args:
        batch: prompt strings, handed to the conditioner under the ``'txt'`` key.

    Returns:
        Whatever ``self.conditioner`` returns for ``{'txt': batch}``.
    """
    embedder_list = self.conditioner.embedders
    for embedder_module in embedder_list:
        embedder_module.ucg_rate = 0.0

    return self.conditioner({'txt': batch})
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
    """Call ``self.model`` with ``(x, t, cond)`` positionally and return its result.

    This is installed on ``DiffusionEngine`` at module import, presumably so
    callers written against ldm's ``apply_model(x, t, cond)`` interface can
    drive sgm engines unchanged (verify against callers).
    """
    wrapped_network = self.model
    prediction = wrapped_network(x, t, cond)
    return prediction
def extend_sdxl(model):
    """Attach SD1.x-style attributes to an sgm ``DiffusionEngine`` in place.

    Presumably this lets code written for ldm ``LatentDiffusion`` models read
    the same attribute names from an SDXL engine (confirm against callers).

    Sets, on ``model``:
      * ``model.model.diffusion_model.dtype``: dtype of the UNet's first parameter.
      * ``model.model.conditioning_key``: ``'crossattn'``.
      * ``model.cond_stage_model``: the first conditioner embedder whose class
        is named ``FrozenOpenCLIPEmbedder``.
      * ``model.cond_stage_key``: that embedder's ``input_key``.
      * ``model.parameterization``: ``"v"`` when ``model.denoiser.scaling`` is a
        ``VScaling``, otherwise ``"eps"``.
      * ``model.alphas_cumprod``: the legacy DDPM schedule as a float32 tensor
        on ``devices.device``.

    Raises:
        IndexError: if no embedder named ``FrozenOpenCLIPEmbedder`` exists.
            This is the same exception type as before, now with a message.
    """
    dtype = next(model.model.diffusion_model.parameters()).dtype
    model.model.diffusion_model.dtype = dtype
    model.model.conditioning_key = 'crossattn'

    # Match by class name rather than isinstance; the first match wins.
    open_clip_embedder = next(
        (emb for emb in model.conditioner.embedders if type(emb).__name__ == 'FrozenOpenCLIPEmbedder'),
        None,
    )
    if open_clip_embedder is None:
        raise IndexError("SDXL conditioner has no FrozenOpenCLIPEmbedder among its embedders")
    model.cond_stage_model = open_clip_embedder
    model.cond_stage_key = model.cond_stage_model.input_key

    model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"

    discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
    # Keep the noise schedule in float32 whatever the UNet dtype is.  Early
    # alphas_cumprod values sit just below 1.0, where float16 spacing (~4.9e-4)
    # is comparable to 1 - alpha, so sigmas derived from a half-precision copy
    # are visibly wrong at the low-noise end.
    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
# Install the two adapter functions defined above as methods on sgm's
# DiffusionEngine class.  This mutates the third-party class globally, so
# every DiffusionEngine instance gains them at import time of this module.
for _patched_name, _patched_impl in (
    ("get_learned_conditioning", get_learned_conditioning),
    ("apply_model", apply_model),
):
    setattr(sgm.models.diffusion.DiffusionEngine, _patched_name, _patched_impl)
# Drop the loop variables so the module namespace matches the original.
del _patched_name, _patched_impl