ruff manual fixes
This commit is contained in:
parent
028d3f6425
commit
550256db1c
@ -24,7 +24,7 @@ class VQModel(pl.LightningModule):
|
|||||||
n_embed,
|
n_embed,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
image_key="image",
|
image_key="image",
|
||||||
colorize_nlabels=None,
|
colorize_nlabels=None,
|
||||||
monitor=None,
|
monitor=None,
|
||||||
@ -62,7 +62,7 @@ class VQModel(pl.LightningModule):
|
|||||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.lr_g_factor = lr_g_factor
|
self.lr_g_factor = lr_g_factor
|
||||||
|
|
||||||
@ -81,11 +81,11 @@ class VQModel(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
def init_from_ckpt(self, path, ignore_keys=None):
|
||||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys or []:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
del sd[k]
|
del sd[k]
|
||||||
@ -270,7 +270,7 @@ class VQModel(pl.LightningModule):
|
|||||||
|
|
||||||
class VQModelInterface(VQModel):
|
class VQModelInterface(VQModel):
|
||||||
def __init__(self, embed_dim, *args, **kwargs):
|
def __init__(self, embed_dim, *args, **kwargs):
|
||||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
super().__init__(*args, embed_dim=embed_dim, **kwargs)
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
|
@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
beta_schedule="linear",
|
beta_schedule="linear",
|
||||||
loss_type="l2",
|
loss_type="l2",
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
load_only_unet=False,
|
load_only_unet=False,
|
||||||
monitor="val/loss",
|
monitor="val/loss",
|
||||||
use_ema=True,
|
use_ema=True,
|
||||||
@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
if monitor is not None:
|
if monitor is not None:
|
||||||
self.monitor = monitor
|
self.monitor = monitor
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||||
|
|
||||||
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
||||||
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||||
@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||||
sd = torch.load(path, map_location="cpu")
|
sd = torch.load(path, map_location="cpu")
|
||||||
if "state_dict" in list(sd.keys()):
|
if "state_dict" in list(sd.keys()):
|
||||||
sd = sd["state_dict"]
|
sd = sd["state_dict"]
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys or []:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
del sd[k]
|
del sd[k]
|
||||||
@ -444,7 +444,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
conditioning_key = None
|
conditioning_key = None
|
||||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||||
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
|
super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
|
||||||
self.concat_mode = concat_mode
|
self.concat_mode = concat_mode
|
||||||
self.cond_stage_trainable = cond_stage_trainable
|
self.cond_stage_trainable = cond_stage_trainable
|
||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
@ -1418,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
|
|||||||
# TODO: move all layout-specific hacks to this class
|
# TODO: move all layout-specific hacks to this class
|
||||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||||
|
|
||||||
def log_images(self, batch, N=8, *args, **kwargs):
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||||
|
|
||||||
key = 'train' if self.training else 'validation'
|
key = 'train' if self.training else 'validation'
|
||||||
dset = self.trainer.datamodule.datasets[key]
|
dset = self.trainer.datamodule.datasets[key]
|
||||||
|
@ -644,13 +644,17 @@ class SwinIR(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||||
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
embed_dim=96, depths=None, num_heads=None,
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(SwinIR, self).__init__()
|
super(SwinIR, self).__init__()
|
||||||
|
|
||||||
|
depths = depths or [6, 6, 6, 6]
|
||||||
|
num_heads = num_heads or [6, 6, 6, 6]
|
||||||
|
|
||||||
num_in_ch = in_chans
|
num_in_ch = in_chans
|
||||||
num_out_ch = in_chans
|
num_out_ch = in_chans
|
||||||
num_feat = 64
|
num_feat = 64
|
||||||
|
@ -74,9 +74,12 @@ class WindowAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
||||||
pretrained_window_size=[0, 0]):
|
pretrained_window_size=None):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
pretrained_window_size = pretrained_window_size or [0, 0]
|
||||||
|
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.window_size = window_size # Wh, Ww
|
self.window_size = window_size # Wh, Ww
|
||||||
self.pretrained_window_size = pretrained_window_size
|
self.pretrained_window_size = pretrained_window_size
|
||||||
@ -698,13 +701,17 @@ class Swin2SR(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||||
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
embed_dim=96, depths=None, num_heads=None,
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(Swin2SR, self).__init__()
|
super(Swin2SR, self).__init__()
|
||||||
|
|
||||||
|
depths = depths or [6, 6, 6, 6]
|
||||||
|
num_heads = num_heads or [6, 6, 6, 6]
|
||||||
|
|
||||||
num_in_ch = in_chans
|
num_in_ch = in_chans
|
||||||
num_out_ch = in_chans
|
num_out_ch = in_chans
|
||||||
num_feat = 64
|
num_feat = 64
|
||||||
|
@ -34,14 +34,16 @@ import piexif.helper
|
|||||||
def upscaler_to_index(name: str):
|
def upscaler_to_index(name: str):
|
||||||
try:
|
try:
|
||||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||||
except Exception:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}")
|
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
|
||||||
|
|
||||||
|
|
||||||
def script_name_to_index(name, scripts):
|
def script_name_to_index(name, scripts):
|
||||||
try:
|
try:
|
||||||
return [script.title().lower() for script in scripts].index(name.lower())
|
return [script.title().lower() for script in scripts].index(name.lower())
|
||||||
except Exception:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
|
raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
|
||||||
|
|
||||||
|
|
||||||
def validate_sampler_name(name):
|
def validate_sampler_name(name):
|
||||||
config = sd_samplers.all_samplers_map.get(name, None)
|
config = sd_samplers.all_samplers_map.get(name, None)
|
||||||
@ -50,20 +52,23 @@ def validate_sampler_name(name):
|
|||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def setUpscalers(req: dict):
|
def setUpscalers(req: dict):
|
||||||
reqDict = vars(req)
|
reqDict = vars(req)
|
||||||
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
||||||
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
||||||
return reqDict
|
return reqDict
|
||||||
|
|
||||||
|
|
||||||
def decode_base64_to_image(encoding):
|
def decode_base64_to_image(encoding):
|
||||||
if encoding.startswith("data:image/"):
|
if encoding.startswith("data:image/"):
|
||||||
encoding = encoding.split(";")[1].split(",")[1]
|
encoding = encoding.split(";")[1].split(",")[1]
|
||||||
try:
|
try:
|
||||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||||
return image
|
return image
|
||||||
except Exception:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||||
|
|
||||||
|
|
||||||
def encode_pil_to_base64(image):
|
def encode_pil_to_base64(image):
|
||||||
with io.BytesIO() as output_bytes:
|
with io.BytesIO() as output_bytes:
|
||||||
@ -94,6 +99,7 @@ def encode_pil_to_base64(image):
|
|||||||
|
|
||||||
return base64.b64encode(bytes_data)
|
return base64.b64encode(bytes_data)
|
||||||
|
|
||||||
|
|
||||||
def api_middleware(app: FastAPI):
|
def api_middleware(app: FastAPI):
|
||||||
rich_available = True
|
rich_available = True
|
||||||
try:
|
try:
|
||||||
|
@ -161,10 +161,13 @@ class Fuse_sft_block(nn.Module):
|
|||||||
class CodeFormer(VQAutoEncoder):
|
class CodeFormer(VQAutoEncoder):
|
||||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||||
codebook_size=1024, latent_size=256,
|
codebook_size=1024, latent_size=256,
|
||||||
connect_list=['32', '64', '128', '256'],
|
connect_list=None,
|
||||||
fix_modules=['quantize','generator']):
|
fix_modules=None):
|
||||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||||
|
|
||||||
|
connect_list = connect_list or ['32', '64', '128', '256']
|
||||||
|
fix_modules = fix_modules or ['quantize', 'generator']
|
||||||
|
|
||||||
if fix_modules is not None:
|
if fix_modules is not None:
|
||||||
for module in fix_modules:
|
for module in fix_modules:
|
||||||
for param in getattr(self, module).parameters():
|
for param in getattr(self, module).parameters():
|
||||||
|
@ -326,7 +326,7 @@ class Generator(nn.Module):
|
|||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
@ARCH_REGISTRY.register()
|
||||||
class VQAutoEncoder(nn.Module):
|
class VQAutoEncoder(nn.Module):
|
||||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
|
||||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
logger = get_root_logger()
|
logger = get_root_logger()
|
||||||
@ -337,7 +337,7 @@ class VQAutoEncoder(nn.Module):
|
|||||||
self.embed_dim = emb_dim
|
self.embed_dim = emb_dim
|
||||||
self.ch_mult = ch_mult
|
self.ch_mult = ch_mult
|
||||||
self.resolution = img_size
|
self.resolution = img_size
|
||||||
self.attn_resolutions = attn_resolutions
|
self.attn_resolutions = attn_resolutions or [16]
|
||||||
self.quantizer_type = quantizer
|
self.quantizer_type = quantizer
|
||||||
self.encoder = Encoder(
|
self.encoder = Encoder(
|
||||||
self.in_channels,
|
self.in_channels,
|
||||||
|
@ -19,14 +19,14 @@ registered_param_bindings = []
|
|||||||
|
|
||||||
|
|
||||||
class ParamBinding:
|
class ParamBinding:
|
||||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
|
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
||||||
self.paste_button = paste_button
|
self.paste_button = paste_button
|
||||||
self.tabname = tabname
|
self.tabname = tabname
|
||||||
self.source_text_component = source_text_component
|
self.source_text_component = source_text_component
|
||||||
self.source_image_component = source_image_component
|
self.source_image_component = source_image_component
|
||||||
self.source_tabname = source_tabname
|
self.source_tabname = source_tabname
|
||||||
self.override_settings_component = override_settings_component
|
self.override_settings_component = override_settings_component
|
||||||
self.paste_field_names = paste_field_names
|
self.paste_field_names = paste_field_names or []
|
||||||
|
|
||||||
|
|
||||||
def reset():
|
def reset():
|
||||||
|
@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
|
|||||||
beta_schedule="linear",
|
beta_schedule="linear",
|
||||||
loss_type="l2",
|
loss_type="l2",
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
load_only_unet=False,
|
load_only_unet=False,
|
||||||
monitor="val/loss",
|
monitor="val/loss",
|
||||||
use_ema=True,
|
use_ema=True,
|
||||||
@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
|
|||||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||||
|
|
||||||
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
||||||
if self.use_ema and not load_ema:
|
if self.use_ema and not load_ema:
|
||||||
@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||||
|
ignore_keys = ignore_keys or []
|
||||||
|
|
||||||
sd = torch.load(path, map_location="cpu")
|
sd = torch.load(path, map_location="cpu")
|
||||||
if "state_dict" in list(sd.keys()):
|
if "state_dict" in list(sd.keys()):
|
||||||
sd = sd["state_dict"]
|
sd = sd["state_dict"]
|
||||||
@ -473,7 +475,7 @@ class LatentDiffusion(DDPM):
|
|||||||
conditioning_key = None
|
conditioning_key = None
|
||||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||||
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
|
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
|
||||||
self.concat_mode = concat_mode
|
self.concat_mode = concat_mode
|
||||||
self.cond_stage_trainable = cond_stage_trainable
|
self.cond_stage_trainable = cond_stage_trainable
|
||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
@ -1433,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
|||||||
# TODO: move all layout-specific hacks to this class
|
# TODO: move all layout-specific hacks to this class
|
||||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||||
|
|
||||||
def log_images(self, batch, N=8, *args, **kwargs):
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||||
|
|
||||||
key = 'train' if self.training else 'validation'
|
key = 'train' if self.training else 'validation'
|
||||||
dset = self.trainer.datamodule.datasets[key]
|
dset = self.trainer.datamodule.datasets[key]
|
||||||
|
@ -178,13 +178,13 @@ def model_wrapper(
|
|||||||
model,
|
model,
|
||||||
noise_schedule,
|
noise_schedule,
|
||||||
model_type="noise",
|
model_type="noise",
|
||||||
model_kwargs={},
|
model_kwargs=None,
|
||||||
guidance_type="uncond",
|
guidance_type="uncond",
|
||||||
#condition=None,
|
#condition=None,
|
||||||
#unconditional_condition=None,
|
#unconditional_condition=None,
|
||||||
guidance_scale=1.,
|
guidance_scale=1.,
|
||||||
classifier_fn=None,
|
classifier_fn=None,
|
||||||
classifier_kwargs={},
|
classifier_kwargs=None,
|
||||||
):
|
):
|
||||||
"""Create a wrapper function for the noise prediction model.
|
"""Create a wrapper function for the noise prediction model.
|
||||||
|
|
||||||
@ -275,6 +275,9 @@ def model_wrapper(
|
|||||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_kwargs = model_kwargs or []
|
||||||
|
classifier_kwargs = classifier_kwargs or []
|
||||||
|
|
||||||
def get_model_input_time(t_continuous):
|
def get_model_input_time(t_continuous):
|
||||||
"""
|
"""
|
||||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||||
|
@ -104,7 +104,7 @@ def check_pt(filename, extra_handler):
|
|||||||
|
|
||||||
|
|
||||||
def load(filename, *args, **kwargs):
|
def load(filename, *args, **kwargs):
|
||||||
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
|
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||||
|
@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||||
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
||||||
|
|
||||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
|
||||||
|
|
||||||
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ class EmbeddingEncoder(json.JSONEncoder):
|
|||||||
|
|
||||||
class EmbeddingDecoder(json.JSONDecoder):
|
class EmbeddingDecoder(json.JSONDecoder):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
|
json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
|
||||||
|
|
||||||
def object_hook(self, d):
|
def object_hook(self, d):
|
||||||
if 'TORCHTENSOR' in d:
|
if 'TORCHTENSOR' in d:
|
||||||
|
@ -32,8 +32,8 @@ class LearnScheduleIterator:
|
|||||||
self.maxit += 1
|
self.maxit += 1
|
||||||
return
|
return
|
||||||
assert self.rates
|
assert self.rates
|
||||||
except (ValueError, AssertionError):
|
except (ValueError, AssertionError) as e:
|
||||||
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
|
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
|
||||||
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
@ -24,6 +24,9 @@ ignore = [
|
|||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
[tool.ruff.per-file-ignores]
|
[tool.ruff.per-file-ignores]
|
||||||
"webui.py" = ["E402"] # Module level import not at top of file
|
"webui.py" = ["E402"] # Module level import not at top of file
|
||||||
|
|
||||||
|
[tool.ruff.flake8-bugbear]
|
||||||
|
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
|
||||||
|
extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
|
Loading…
Reference in New Issue
Block a user