41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import torch
|
||
|
|
||
|
import sgm.models.diffusion
|
||
|
import sgm.modules.diffusionmodules.denoiser_scaling
|
||
|
import sgm.modules.diffusionmodules.discretizer
|
||
|
from modules import devices
|
||
|
|
||
|
|
||
|
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]):
    """Run the engine's conditioner on a batch of text prompts.

    Every embedder attached to ``self.conditioner`` has its ``ucg_rate``
    attribute reset to ``0.0`` before the conditioner is invoked.

    Args:
        self: Engine whose ``conditioner`` is used.
        batch: Prompt strings, handed to the conditioner under the ``'txt'`` key.

    Returns:
        Whatever ``self.conditioner`` returns for ``{'txt': batch}``.
    """
    conditioner = self.conditioner

    # NOTE(review): presumably ucg_rate is the probability of randomly dropping
    # the conditioning, so 0.0 would disable it - confirm against sgm.
    for emb in conditioner.embedders:
        emb.ucg_rate = 0.0

    return conditioner({'txt': batch})
|
||
|
|
||
|
|
||
|
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
    """Call the engine's wrapped model and return its result.

    Args:
        self: Engine whose ``model`` attribute is invoked.
        x: First positional argument forwarded to ``self.model``.
        t: Second positional argument forwarded to ``self.model``.
        cond: Third positional argument forwarded to ``self.model``.

    Returns:
        The value produced by ``self.model(x, t, cond)``.
    """
    wrapped = self.model
    return wrapped(x, t, cond)
|
||
|
|
||
|
|
||
|
def extend_sdxl(model):
    """Attach extra attributes to an SDXL model object, in place.

    Attributes written:
        * ``model.model.diffusion_model.dtype`` - dtype of the diffusion
          model's first parameter.
        * ``model.model.conditioning_key`` - the constant ``'crossattn'``.
        * ``model.cond_stage_model`` - first conditioner embedder whose class
          name is ``FrozenOpenCLIPEmbedder``.
        * ``model.cond_stage_key`` - that embedder's ``input_key``.
        * ``model.parameterization`` - ``"v"`` when the denoiser scaling is a
          ``VScaling`` instance, otherwise ``"eps"``.
        * ``model.alphas_cumprod`` - ``alphas_cumprod`` of a default
          ``LegacyDDPMDiscretization``, as a tensor on ``devices.device`` with
          the dtype above.

    Args:
        model: The model object to extend.

    Raises:
        StopIteration: If the diffusion model yields no parameters.
        IndexError: If no embedder is named ``FrozenOpenCLIPEmbedder``.
    """
    unet = model.model.diffusion_model
    dtype = next(unet.parameters()).dtype
    unet.dtype = dtype
    model.model.conditioning_key = 'crossattn'

    # The embedder is matched by class name rather than with isinstance().
    open_clip_embedders = [
        candidate
        for candidate in model.conditioner.embedders
        if type(candidate).__name__ == 'FrozenOpenCLIPEmbedder'
    ]
    model.cond_stage_model = open_clip_embedders[0]
    model.cond_stage_key = model.cond_stage_model.input_key

    v_scaling_cls = sgm.modules.diffusionmodules.denoiser_scaling.VScaling
    if isinstance(model.denoiser.scaling, v_scaling_cls):
        model.parameterization = "v"
    else:
        model.parameterization = "eps"

    schedule = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
    model.alphas_cumprod = torch.asarray(
        schedule.alphas_cumprod,
        device=devices.device,
        dtype=dtype,
    )
|
||
|
|
||
|
|
||
|
# Install the helpers defined in this module as methods on sgm's
# DiffusionEngine class, so every engine instance exposes
# `get_learned_conditioning` and `apply_model`.
setattr(sgm.models.diffusion.DiffusionEngine, 'get_learned_conditioning', get_learned_conditioning)
setattr(sgm.models.diffusion.DiffusionEngine, 'apply_model', apply_model)
|
||
|
|