ruff manual fixes
This commit is contained in:
parent
028d3f6425
commit
550256db1c
@ -24,7 +24,7 @@ class VQModel(pl.LightningModule):
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
ignore_keys=None,
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
@ -62,7 +62,7 @@ class VQModel(pl.LightningModule):
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lr_g_factor = lr_g_factor
|
||||
|
||||
@ -81,11 +81,11 @@ class VQModel(pl.LightningModule):
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
def init_from_ckpt(self, path, ignore_keys=None):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
for ik in ignore_keys or []:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
@ -270,7 +270,7 @@ class VQModel(pl.LightningModule):
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
super().__init__(*args, embed_dim=embed_dim, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def encode(self, x):
|
||||
|
@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
|
||||
beta_schedule="linear",
|
||||
loss_type="l2",
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
ignore_keys=None,
|
||||
load_only_unet=False,
|
||||
monitor="val/loss",
|
||||
use_ema=True,
|
||||
@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||
|
||||
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
||||
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||
@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
if "state_dict" in list(sd.keys()):
|
||||
sd = sd["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
for ik in ignore_keys or []:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
@ -444,7 +444,7 @@ class LatentDiffusionV1(DDPMV1):
|
||||
conditioning_key = None
|
||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
|
||||
super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
|
||||
self.concat_mode = concat_mode
|
||||
self.cond_stage_trainable = cond_stage_trainable
|
||||
self.cond_stage_key = cond_stage_key
|
||||
@ -1418,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
|
||||
# TODO: move all layout-specific hacks to this class
|
||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
||||
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||
|
||||
def log_images(self, batch, N=8, *args, **kwargs):
|
||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
||||
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||
|
||||
key = 'train' if self.training else 'validation'
|
||||
dset = self.trainer.datamodule.datasets[key]
|
||||
|
@ -644,13 +644,17 @@ class SwinIR(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
||||
embed_dim=96, depths=None, num_heads=None,
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||
**kwargs):
|
||||
super(SwinIR, self).__init__()
|
||||
|
||||
depths = depths or [6, 6, 6, 6]
|
||||
num_heads = num_heads or [6, 6, 6, 6]
|
||||
|
||||
num_in_ch = in_chans
|
||||
num_out_ch = in_chans
|
||||
num_feat = 64
|
||||
|
@ -74,9 +74,12 @@ class WindowAttention(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
||||
pretrained_window_size=[0, 0]):
|
||||
pretrained_window_size=None):
|
||||
|
||||
super().__init__()
|
||||
|
||||
pretrained_window_size = pretrained_window_size or [0, 0]
|
||||
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.pretrained_window_size = pretrained_window_size
|
||||
@ -698,13 +701,17 @@ class Swin2SR(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
||||
embed_dim=96, depths=None, num_heads=None,
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||
**kwargs):
|
||||
super(Swin2SR, self).__init__()
|
||||
|
||||
depths = depths or [6, 6, 6, 6]
|
||||
num_heads = num_heads or [6, 6, 6, 6]
|
||||
|
||||
num_in_ch = in_chans
|
||||
num_out_ch = in_chans
|
||||
num_feat = 64
|
||||
|
@ -34,14 +34,16 @@ import piexif.helper
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
|
||||
|
||||
|
||||
def script_name_to_index(name, scripts):
|
||||
try:
|
||||
return [script.title().lower() for script in scripts].index(name.lower())
|
||||
except Exception:
|
||||
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
|
||||
|
||||
|
||||
def validate_sampler_name(name):
|
||||
config = sd_samplers.all_samplers_map.get(name, None)
|
||||
@ -50,20 +52,23 @@ def validate_sampler_name(name):
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def setUpscalers(req: dict):
|
||||
reqDict = vars(req)
|
||||
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
||||
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
||||
return reqDict
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
@ -94,6 +99,7 @@ def encode_pil_to_base64(image):
|
||||
|
||||
return base64.b64encode(bytes_data)
|
||||
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
rich_available = True
|
||||
try:
|
||||
|
@ -161,10 +161,13 @@ class Fuse_sft_block(nn.Module):
|
||||
class CodeFormer(VQAutoEncoder):
|
||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||
codebook_size=1024, latent_size=256,
|
||||
connect_list=['32', '64', '128', '256'],
|
||||
fix_modules=['quantize','generator']):
|
||||
connect_list=None,
|
||||
fix_modules=None):
|
||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||
|
||||
connect_list = connect_list or ['32', '64', '128', '256']
|
||||
fix_modules = fix_modules or ['quantize', 'generator']
|
||||
|
||||
if fix_modules is not None:
|
||||
for module in fix_modules:
|
||||
for param in getattr(self, module).parameters():
|
||||
|
@ -326,7 +326,7 @@ class Generator(nn.Module):
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class VQAutoEncoder(nn.Module):
|
||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
|
||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
||||
super().__init__()
|
||||
logger = get_root_logger()
|
||||
@ -337,7 +337,7 @@ class VQAutoEncoder(nn.Module):
|
||||
self.embed_dim = emb_dim
|
||||
self.ch_mult = ch_mult
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.attn_resolutions = attn_resolutions or [16]
|
||||
self.quantizer_type = quantizer
|
||||
self.encoder = Encoder(
|
||||
self.in_channels,
|
||||
|
@ -19,14 +19,14 @@ registered_param_bindings = []
|
||||
|
||||
|
||||
class ParamBinding:
|
||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
|
||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
||||
self.paste_button = paste_button
|
||||
self.tabname = tabname
|
||||
self.source_text_component = source_text_component
|
||||
self.source_image_component = source_image_component
|
||||
self.source_tabname = source_tabname
|
||||
self.override_settings_component = override_settings_component
|
||||
self.paste_field_names = paste_field_names
|
||||
self.paste_field_names = paste_field_names or []
|
||||
|
||||
|
||||
def reset():
|
||||
|
@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
|
||||
beta_schedule="linear",
|
||||
loss_type="l2",
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
ignore_keys=None,
|
||||
load_only_unet=False,
|
||||
monitor="val/loss",
|
||||
use_ema=True,
|
||||
@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||
|
||||
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
||||
if self.use_ema and not load_ema:
|
||||
@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||
ignore_keys = ignore_keys or []
|
||||
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
if "state_dict" in list(sd.keys()):
|
||||
sd = sd["state_dict"]
|
||||
@ -473,7 +475,7 @@ class LatentDiffusion(DDPM):
|
||||
conditioning_key = None
|
||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
|
||||
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
|
||||
self.concat_mode = concat_mode
|
||||
self.cond_stage_trainable = cond_stage_trainable
|
||||
self.cond_stage_key = cond_stage_key
|
||||
@ -1433,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
||||
# TODO: move all layout-specific hacks to this class
|
||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
||||
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||
|
||||
def log_images(self, batch, N=8, *args, **kwargs):
|
||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
||||
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||
|
||||
key = 'train' if self.training else 'validation'
|
||||
dset = self.trainer.datamodule.datasets[key]
|
||||
|
@ -178,13 +178,13 @@ def model_wrapper(
|
||||
model,
|
||||
noise_schedule,
|
||||
model_type="noise",
|
||||
model_kwargs={},
|
||||
model_kwargs=None,
|
||||
guidance_type="uncond",
|
||||
#condition=None,
|
||||
#unconditional_condition=None,
|
||||
guidance_scale=1.,
|
||||
classifier_fn=None,
|
||||
classifier_kwargs={},
|
||||
classifier_kwargs=None,
|
||||
):
|
||||
"""Create a wrapper function for the noise prediction model.
|
||||
|
||||
@ -275,6 +275,9 @@ def model_wrapper(
|
||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||
"""
|
||||
|
||||
model_kwargs = model_kwargs or []
|
||||
classifier_kwargs = classifier_kwargs or []
|
||||
|
||||
def get_model_input_time(t_continuous):
|
||||
"""
|
||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||
|
@ -104,7 +104,7 @@ def check_pt(filename, extra_handler):
|
||||
|
||||
|
||||
def load(filename, *args, **kwargs):
|
||||
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
|
||||
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
|
||||
|
||||
|
||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||
|
@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
|
||||
|
||||
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
||||
|
||||
|
@ -17,7 +17,7 @@ class EmbeddingEncoder(json.JSONEncoder):
|
||||
|
||||
class EmbeddingDecoder(json.JSONDecoder):
|
||||
def __init__(self, *args, **kwargs):
|
||||
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
|
||||
json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
|
||||
|
||||
def object_hook(self, d):
|
||||
if 'TORCHTENSOR' in d:
|
||||
|
@ -32,8 +32,8 @@ class LearnScheduleIterator:
|
||||
self.maxit += 1
|
||||
return
|
||||
assert self.rates
|
||||
except (ValueError, AssertionError):
|
||||
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
|
||||
except (ValueError, AssertionError) as e:
|
||||
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -24,6 +24,9 @@ ignore = [
|
||||
|
||||
]
|
||||
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"webui.py" = ["E402"] # Module level import not at top of file
|
||||
|
||||
[tool.ruff.flake8-bugbear]
|
||||
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
|
||||
extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
|
Loading…
Reference in New Issue
Block a user