Add Tiny AE live preview

Sakura-Luna 2023-05-14 12:42:44 +08:00
parent b08500cec8
commit e14b586d04
4 changed files with 101 additions and 9 deletions

@@ -2,7 +2,7 @@ from collections import namedtuple
 import numpy as np
 import torch
 from PIL import Image
-from modules import devices, processing, images, sd_vae_approx
+from modules import devices, processing, images, sd_vae_approx, sd_vae_taesd
 from modules.shared import opts, state
 import modules.shared as shared
@@ -22,21 +22,26 @@ def setup_img2img_steps(p, steps=None):
     return steps, t_enc


-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3}


 def single_sample_to_image(sample, approximation=None):
     if approximation is None:
         approximation = approximation_indexes.get(opts.show_progress_type, 0)

-    if approximation == 2:
-        x_sample = sd_vae_approx.cheap_approximation(sample)
-    elif approximation == 1:
-        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+    if approximation == 1:
+        x_sample = sd_vae_taesd.decode()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+        x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample)
+        x_sample = torch.clamp((x_sample * 0.25) + 0.5, 0, 1)
     else:
-        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
-
-    x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+        if approximation == 3:
+            x_sample = sd_vae_approx.cheap_approximation(sample)
+        elif approximation == 2:
+            x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+        else:
+            x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)

     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
     x_sample = x_sample.astype(np.uint8)

     return Image.fromarray(x_sample)
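
For orientation, here is a minimal usage sketch of the changed function. It is not part of the diff; it assumes the hunk above belongs to modules/sd_samplers_common.py, as in upstream webui at the time, and that a model is already loaded.

# Hypothetical caller, for illustration only.
from modules import sd_samplers_common   # assumed module path for the hunk above

def taesd_preview(latent_sample):
    # latent_sample: one (4, H/8, W/8) latent from the sampler.
    # approximation=1 now selects the Tiny AE (TAESD) decoder; passing None
    # falls back to the "Live preview" setting (opts.show_progress_type).
    return sd_samplers_common.single_sample_to_image(latent_sample, approximation=1)
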

modules/sd_vae_taesd.py (new file)

@@ -0,0 +1,76 @@
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
https://github.com/madebyollin/taesd
"""
import os
import torch
import torch.nn as nn
from modules import devices, paths_internal
sd_vae_taesd = None
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
@staticmethod
def forward(x):
return torch.tanh(x / 3) * 3
class Block(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))
def decoder():
return nn.Sequential(
Clamp(), conv(4, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
)
class TAESD(nn.Module):
latent_magnitude = 2
latent_shift = 0.5
def __init__(self, decoder_path="taesd_decoder.pth"):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.decoder = decoder()
self.decoder.load_state_dict(
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
@staticmethod
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode():
global sd_vae_taesd
if sd_vae_taesd is None:
model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth")
if os.path.exists(model_path):
sd_vae_taesd = TAESD(model_path)
sd_vae_taesd.eval()
sd_vae_taesd.to(devices.device, devices.dtype)
else:
            raise FileNotFoundError('Tiny AE model not found')

    return sd_vae_taesd.decoder
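
A short sketch of how the new module can be exercised directly. This is not part of the commit; it assumes it runs inside the webui process (so the modules.* imports and device setup resolve) and that taesd_decoder.pth is already in models/VAE-approx, otherwise decode() raises FileNotFoundError. The post-processing mirrors the Tiny AE branch added to single_sample_to_image above.

import numpy as np
import torch
from PIL import Image
from modules import devices, sd_vae_taesd

# A random latent standing in for a sampler step; a 64x64 latent corresponds
# to a 512x512 image (for illustration only).
latent = torch.randn(1, 4, 64, 64, device=devices.device, dtype=devices.dtype)

with torch.no_grad():
    x = sd_vae_taesd.decode()(latent)[0]        # (3, 512, 512) decoded tensor
    x = sd_vae_taesd.TAESD.unscale_latents(x)   # same scaling as the live-preview branch
    x = torch.clamp((x * 0.25) + 0.5, 0, 1)

preview = Image.fromarray((255. * np.moveaxis(x.float().cpu().numpy(), 0, 2)).astype(np.uint8))
preview.save("taesd_preview.png")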


@@ -425,7 +425,7 @@ options_templates.update(options_section(('ui', "Live previews"), {
     "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
     "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
     "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
-    "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+    "show_progress_type": OptionInfo("Tiny AE", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Tiny AE", "Approx NN", "Approx cheap"]}),
     "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
     "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
}))
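
For reference, a small sketch of how this setting reaches the preview code; the dictionary literal is copied from the sampler hunk above, and the opts access assumes the webui process.

from modules.shared import opts

# The radio choice is stored as a string and mapped to an approximation index.
approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3}

index = approximation_indexes.get(opts.show_progress_type, 0)
# "Tiny AE" (the new default) -> 1, which is the TAESD branch in single_sample_to_image.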


@@ -144,10 +144,21 @@ Use --skip-version-check commandline argument to disable this check.
 """.strip())


+def check_taesd():
+    from modules.paths_internal import models_path
+
+    model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
+    model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth")
+    if not os.path.exists(model_path):
+        print('Downloading TAESD decoder...')
+        torch.hub.download_url_to_file(model_url, model_path)
+
+
 def initialize():
     fix_asyncio_event_loop_policy()

     check_versions()
+    check_taesd()
+
     extensions.list_extensions()
     localization.list_localizations(cmd_opts.localizations_dir)
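
For completeness, a standalone version of the same download step that can be run outside the webui; the URL and relative target path come from the diff above, while models_path is an assumption matching a default install.

import os
import torch

models_path = "models"   # assumed: <webui root>/models
model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth")

if not os.path.exists(model_path):
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    # torch.hub.download_url_to_file expects the destination *file* path.
    torch.hub.download_url_to_file(model_url, model_path)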