remove unwanted formatting/functionality from the PR

This commit is contained in:
AUTOMATIC 2022-09-30 11:42:40 +03:00
parent 2552204fcb
commit d1f098540a
11 changed files with 127 additions and 175 deletions

View File

@ -1,5 +1,4 @@
# this scripts installs necessary requirements and launches main program in webui.py # this scripts installs necessary requirements and launches main program in webui.py
import shutil
import subprocess import subprocess
import os import os
import sys import sys
@ -119,11 +118,7 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
if os.path.isdir(repo_dir('latent-diffusion')):
try:
shutil.rmtree(repo_dir('latent-diffusion'))
except:
pass
if not is_installed("lpips"): if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")

View File

@ -13,6 +13,63 @@ from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts from modules.shared import opts
def fix_model_layers(crt_model, pretrained_net):
# this code is adapted from https://github.com/xinntao/ESRGAN
if 'conv_first.weight' in pretrained_net:
return pretrained_net
if 'model.0.weight' not in pretrained_net:
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
if is_realesrgan:
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
else:
raise Exception("The file is not a ESRGAN model.")
crt_net = crt_model.state_dict()
load_net_clean = {}
for k, v in pretrained_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v
pretrained_net = load_net_clean
tbd = []
for k, v in crt_net.items():
tbd.append(k)
# directly copy
for k, v in crt_net.items():
if k in pretrained_net and pretrained_net[k].size() == v.size():
crt_net[k] = pretrained_net[k]
tbd.remove(k)
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
for k in tbd.copy():
if 'RDB' in k:
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
if '.weight' in k:
ori_k = ori_k.replace('.weight', '.0.weight')
elif '.bias' in k:
ori_k = ori_k.replace('.bias', '.0.bias')
crt_net[k] = pretrained_net[ori_k]
tbd.remove(k)
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
return crt_net
class UpscalerESRGAN(Upscaler): class UpscalerESRGAN(Upscaler):
def __init__(self, dirname): def __init__(self, dirname):
self.name = "ESRGAN" self.name = "ESRGAN"
@ -28,14 +85,12 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data) scalers.append(scaler_data)
for file in model_paths: for file in model_paths:
print(f"File: {file}")
if "http" in file: if "http" in file:
name = self.model_name name = self.model_name
else: else:
name = modelloader.friendly_name(file) name = modelloader.friendly_name(file)
scaler_data = UpscalerData(name, file, self, 4) scaler_data = UpscalerData(name, file, self, 4)
print(f"ESRGAN: Adding scaler {name}")
self.scalers.append(scaler_data) self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model): def do_upscale(self, img, selected_model):
@ -56,67 +111,14 @@ class UpscalerESRGAN(Upscaler):
if not os.path.exists(filename) or filename is None: if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (self.model_path, filename)) print("Unable to load %s from %s" % (self.model_path, filename))
return None return None
# this code is adapted from https://github.com/xinntao/ESRGAN
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
if 'conv_first.weight' in pretrained_net: pretrained_net = fix_model_layers(crt_model, pretrained_net)
crt_model.load_state_dict(pretrained_net) crt_model.load_state_dict(pretrained_net)
return crt_model
if 'model.0.weight' not in pretrained_net:
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
"params_ema"]
if is_realesrgan:
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
else:
raise Exception("The file is not a ESRGAN model.")
crt_net = crt_model.state_dict()
load_net_clean = {}
for k, v in pretrained_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v
pretrained_net = load_net_clean
tbd = []
for k, v in crt_net.items():
tbd.append(k)
# directly copy
for k, v in crt_net.items():
if k in pretrained_net and pretrained_net[k].size() == v.size():
crt_net[k] = pretrained_net[k]
tbd.remove(k)
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
for k in tbd.copy():
if 'RDB' in k:
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
if '.weight' in k:
ori_k = ori_k.replace('.weight', '.0.weight')
elif '.bias' in k:
ori_k = ori_k.replace('.bias', '.0.bias')
crt_net[k] = pretrained_net[ori_k]
tbd.remove(k)
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
crt_model.load_state_dict(crt_net)
crt_model.eval() crt_model.eval()
return crt_model return crt_model
@ -154,7 +156,6 @@ def esrgan_upscale(model, img):
newrow.append([x * scale_factor, w * scale_factor, output]) newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow]) newtiles.append([y * scale_factor, h * scale_factor, newrow])
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
output = images.combine_grid(newgrid) output = images.combine_grid(newgrid)
return output return output

View File

@ -67,28 +67,29 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res image = res
def upscale(image, scaler_index, resize): if upscaling_resize != 1.0:
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) def upscale(image, scaler_index, resize):
pixels = tuple(np.array(small).flatten().tolist()) small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels pixels = tuple(np.array(small).flatten().tolist())
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
c = cached_images.get(key) c = cached_images.get(key)
if c is None: if c is None:
upscaler = shared.sd_upscalers[scaler_index] upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.scaler.upscale(image, resize, upscaler.data_path) c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
cached_images[key] = c cached_images[key] = c
return c return c
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n" info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
res = upscale(image, extras_upscaler_1, upscaling_resize) res = upscale(image, extras_upscaler_1, upscaling_resize)
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
res2 = upscale(image, extras_upscaler_2, upscaling_resize) res2 = upscale(image, extras_upscaler_2, upscaling_resize)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n" info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
res = Image.blend(res, res2, extras_upscaler_2_visibility) res = Image.blend(res, res2, extras_upscaler_2_visibility)
image = res image = res
while len(cached_images) > 2: while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))] del cached_images[next(iter(cached_images.keys()))]

View File

@ -36,8 +36,7 @@ def gfpgann():
else: else:
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
bg_upsampler=None)
model.gfpgan.to(shared.device) model.gfpgan.to(shared.device)
loaded_gfpgan_model = model loaded_gfpgan_model = model
@ -49,8 +48,7 @@ def gfpgan_fix_faces(np_image):
if model is None: if model is None:
return np_image return np_image
np_image_bgr = np_image[:, :, ::-1] np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1] np_image = gfpgan_output_bgr[:, :, ::-1]
if shared.opts.face_restoration_unload: if shared.opts.face_restoration_unload:
@ -79,7 +77,6 @@ def setup_model(dirname):
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
def my_load_file_from_url(**kwargs): def my_load_file_from_url(**kwargs):
print("Setting model_dir to " + model_path)
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
def facex_load_file_from_url(**kwargs): def facex_load_file_from_url(**kwargs):
@ -92,7 +89,6 @@ def setup_model(dirname):
facexlib.detection.load_file_from_url = facex_load_file_from_url facexlib.detection.load_file_from_url = facex_load_file_from_url
facexlib.parsing.load_file_from_url = facex_load_file_from_url2 facexlib.parsing.load_file_from_url = facex_load_file_from_url2
user_path = dirname user_path = dirname
print("Have gfpgan should be true?")
have_gfpgan = True have_gfpgan = True
gfpgan_constructor = GFPGANer gfpgan_constructor = GFPGANer
@ -102,9 +98,7 @@ def setup_model(dirname):
def restore(self, np_image): def restore(self, np_image):
np_image_bgr = np_image[:, :, ::-1] np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
only_center_face=False,
paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1] np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image return np_image

View File

@ -84,10 +84,8 @@ def combine_grid(grid):
r = r.astype(np.uint8) r = r.astype(np.uint8)
return Image.fromarray(r, 'L') return Image.fromarray(r, 'L')
mask_w = make_mask_image( mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)) mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
mask_h = make_mask_image(
np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
combined_image = Image.new("RGB", (grid.image_w, grid.image_h)) combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
for y, h, row in grid.tiles: for y, h, row in grid.tiles:
@ -130,12 +128,10 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
def draw_texts(drawing, draw_x, draw_y, lines): def draw_texts(drawing, draw_x, draw_y, lines):
for i, line in enumerate(lines): for i, line in enumerate(lines):
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active: if not line.is_active:
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
draw_y + line.size[1] // 2), fill=color_inactive, width=4)
draw_y += line.size[1] + line_spacing draw_y += line.size[1] + line_spacing
@ -206,10 +202,8 @@ def draw_prompt_matrix(im, width, height, all_prompts):
prompts_horiz = prompts[:boundary] prompts_horiz = prompts[:boundary]
prompts_vert = prompts[boundary:] prompts_vert = prompts[boundary:]
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
range(1 << len(prompts_horiz))] ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in
range(1 << len(prompts_vert))]
return draw_grid_annotations(im, width, height, hor_texts, ver_texts) return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
@ -259,13 +253,11 @@ def resize_image(resize_mode, im, width, height):
if ratio < src_ratio: if ratio < src_ratio:
fill_height = height // 2 - src_h // 2 fill_height = height // 2 - src_h // 2
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
box=(0, fill_height + src_h))
elif ratio > src_ratio: elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2 fill_width = width // 2 - src_w // 2
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
box=(fill_width + src_w, 0))
return res return res
@ -300,8 +292,7 @@ def apply_filename_pattern(x, p, seed, prompt):
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0] words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
if len(words) == 0: if len(words) == 0:
words = ["empty"] words = ["empty"]
x = x.replace("[prompt_words]", x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
if p is not None: if p is not None:
x = x.replace("[steps]", str(p.steps)) x = x.replace("[steps]", str(p.steps))
@ -309,8 +300,7 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[width]", str(p.width)) x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height)) x = x.replace("[height]", str(p.height))
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False)) x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
x = x.replace("[sampler]", x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat()) x = x.replace("[date]", datetime.date.today().isoformat())
@ -336,8 +326,7 @@ def get_next_sequence_number(path, basename):
prefix_length = len(basename) prefix_length = len(basename)
for p in os.listdir(path): for p in os.listdir(path):
if p.startswith(basename): if p.startswith(basename):
l = os.path.splitext(p[prefix_length:])[0].split( l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
'-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try: try:
result = max(int(l[0]), result) result = max(int(l[0]), result)
except ValueError: except ValueError:
@ -346,9 +335,7 @@ def get_next_sequence_number(path, basename):
return result + 1 return result + 1
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None,
forced_filename=None, suffix=""):
if short_filename or prompt is None or seed is None: if short_filename or prompt is None or seed is None:
file_decoration = "" file_decoration = ""
elif opts.save_to_dirs: elif opts.save_to_dirs:

View File

@ -125,7 +125,6 @@ class LDSR:
del model del model
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
print(f'Processing finished!')
return a return a

View File

@ -25,8 +25,10 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if ext_filter is None: if ext_filter is None:
ext_filter = [] ext_filter = []
try: try:
places = [] places = []
if command_path is not None and command_path != model_path: if command_path is not None and command_path != model_path:
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models') pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
if os.path.exists(pretrained_path): if os.path.exists(pretrained_path):
@ -34,7 +36,9 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
places.append(pretrained_path) places.append(pretrained_path)
elif os.path.exists(command_path): elif os.path.exists(command_path):
places.append(command_path) places.append(command_path)
places.append(model_path) places.append(model_path)
for place in places: for place in places:
if os.path.exists(place): if os.path.exists(place):
for file in os.listdir(place): for file in os.listdir(place):
@ -47,14 +51,17 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
continue continue
if file not in output: if file not in output:
output.append(full_path) output.append(full_path)
if model_url is not None and len(output) == 0: if model_url is not None and len(output) == 0:
if download_name is not None: if download_name is not None:
dl = load_file_from_url(model_url, model_path, True, download_name) dl = load_file_from_url(model_url, model_path, True, download_name)
output.append(dl) output.append(dl)
else: else:
output.append(model_url) output.append(model_url)
except:
except Exception:
pass pass
return output return output

View File

@ -88,28 +88,24 @@ def get_realesrgan_models(scaler):
models = [ models = [
UpscalerData( UpscalerData(
name="R-ESRGAN General 4xV3", name="R-ESRGAN General 4xV3",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3" path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
".pth",
scale=4, scale=4,
upscaler=scaler, upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
act_type='prelu')
), ),
UpscalerData( UpscalerData(
name="R-ESRGAN General WDN 4xV3", name="R-ESRGAN General WDN 4xV3",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
scale=4, scale=4,
upscaler=scaler, upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
act_type='prelu')
), ),
UpscalerData( UpscalerData(
name="R-ESRGAN AnimeVideo", name="R-ESRGAN AnimeVideo",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
scale=4, scale=4,
upscaler=scaler, upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
act_type='prelu')
), ),
UpscalerData( UpscalerData(
name="R-ESRGAN 4x+", name="R-ESRGAN 4x+",

View File

@ -12,10 +12,10 @@ from modules import shared, modelloader
from modules.paths import models_path from modules.paths import models_path
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.join(models_path, model_dir) model_path = os.path.abspath(os.path.join(models_path, model_dir))
model_name = "sd-v1-4.ckpt" model_name = "sd-v1-4.ckpt"
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1" model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
user_dir = None user_dir: (str | None) = None
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {} checkpoints_list = {}
@ -30,26 +30,8 @@ except Exception:
pass pass
def modeltitle(path, h):
abspath = os.path.abspath(path)
if abspath.startswith(model_dir):
name = abspath.replace(model_dir, '')
else:
name = os.path.basename(path)
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
return f'{name} [{h}]'
def setup_model(dirname): def setup_model(dirname):
global model_path
global model_name
global model_url
global user_dir global user_dir
global model_list
user_dir = dirname user_dir = dirname
if not os.path.exists(model_path): if not os.path.exists(model_path):
os.makedirs(model_path) os.makedirs(model_path)
@ -62,21 +44,16 @@ def checkpoint_tiles():
def list_models(): def list_models():
global model_path
global model_url
global model_name
global user_dir
checkpoints_list.clear() checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path,model_url=model_url,command_path= user_dir, model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
ext_filter=[".ckpt"], download_name=model_name)
print(f"Model list: {model_list}")
model_dir = os.path.abspath(model_path)
def modeltitle(path, h): def modeltitle(path, shorthash):
abspath = os.path.abspath(path) abspath = os.path.abspath(path)
if abspath.startswith(model_dir): if user_dir is not None and abspath.startswith(user_dir):
name = abspath.replace(model_dir, '') name = abspath.replace(user_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else: else:
name = os.path.basename(path) name = os.path.basename(path)
@ -85,29 +62,30 @@ def list_models():
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
return f'{name} [{h}]', shortname return f'{name} [{shorthash}]', shortname
cmd_ckpt = shared.cmd_opts.ckpt cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt): if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt) h = model_hash(cmd_ckpt)
title, model_name = modeltitle(cmd_ckpt, h) title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list: for filename in model_list:
h = model_hash(filename) h = model_hash(filename)
title, model_name = modeltitle(filename, h) title, short_model_name = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name) checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
def get_closet_checkpoint_match(searchString): def get_closet_checkpoint_match(searchString):
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title)) applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
if len(applicable)>0: if len(applicable) > 0:
return applicable[0] return applicable[0]
return None return None
def model_hash(filename): def model_hash(filename):
try: try:
print(f"Opening: {filename}")
with open(filename, "rb") as file: with open(filename, "rb") as file:
import hashlib import hashlib
m = hashlib.sha256() m = hashlib.sha256()
@ -128,7 +106,7 @@ def select_checkpoint():
if len(checkpoints_list) == 0: if len(checkpoints_list) == 0:
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {os.path.abspath(shared.cmd_opts.stablediffusion_models_path)}", file=sys.stderr) print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1) exit(1)

View File

@ -21,8 +21,7 @@ model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
# This should be deprecated, but we'll leave it for a few iterations parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints (Deprecated, use '--stablediffusion-models-path'", )
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
@ -41,7 +40,6 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN')) parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
@ -61,10 +59,6 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
if cmd_opts.ckpt_dir is not None:
print("The 'ckpt-dir' arg is deprecated in favor of the 'stablediffusion-models-path' argument and will be "
"removed in a future release. Please use the new option if you wish to use a custom checkpoint directory.")
cmd_opts.__setattr__("stablediffusion-models-path", cmd_opts.ckpt_dir)
device = get_optimal_device() device = get_optimal_device()
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)

View File

@ -28,7 +28,7 @@ from modules.paths import script_path
from modules.shared import cmd_opts from modules.shared import cmd_opts
modelloader.cleanup_models() modelloader.cleanup_models()
modules.sd_models.setup_model(cmd_opts.stablediffusion_models_path) modules.sd_models.setup_model(cmd_opts.ckpt_dir)
codeformer.setup_model(cmd_opts.codeformer_models_path) codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration()) shared.face_restorers.append(modules.face_restoration.FaceRestoration())