remove unwanted formatting/functionality from the PR
This commit is contained in:
parent
2552204fcb
commit
d1f098540a
@ -1,5 +1,4 @@
|
||||
# this scripts installs necessary requirements and launches main program in webui.py
|
||||
import shutil
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
@ -119,11 +118,7 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming
|
||||
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
if os.path.isdir(repo_dir('latent-diffusion')):
|
||||
try:
|
||||
shutil.rmtree(repo_dir('latent-diffusion'))
|
||||
except:
|
||||
pass
|
||||
|
||||
if not is_installed("lpips"):
|
||||
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
||||
|
||||
|
@ -13,60 +13,13 @@ from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class UpscalerESRGAN(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "ESRGAN"
|
||||
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
||||
self.model_name = "ESRGAN 4x"
|
||||
self.scalers = []
|
||||
self.user_path = dirname
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
print(f"File: {file}")
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
|
||||
scaler_data = UpscalerData(name, file, self, 4)
|
||||
print(f"ESRGAN: Adding scaler {name}")
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
model = self.load_model(selected_model)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="%s.pth" % self.model_name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||
return None
|
||||
def fix_model_layers(crt_model, pretrained_net):
|
||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
|
||||
if 'conv_first.weight' in pretrained_net:
|
||||
crt_model.load_state_dict(pretrained_net)
|
||||
return crt_model
|
||||
return pretrained_net
|
||||
|
||||
if 'model.0.weight' not in pretrained_net:
|
||||
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
|
||||
"params_ema"]
|
||||
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
|
||||
if is_realesrgan:
|
||||
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
|
||||
else:
|
||||
@ -115,8 +68,57 @@ class UpscalerESRGAN(Upscaler):
|
||||
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
||||
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
||||
|
||||
crt_model.load_state_dict(crt_net)
|
||||
return crt_net
|
||||
|
||||
class UpscalerESRGAN(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "ESRGAN"
|
||||
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
||||
self.model_name = "ESRGAN 4x"
|
||||
self.scalers = []
|
||||
self.user_path = dirname
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
|
||||
scaler_data = UpscalerData(name, file, self, 4)
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
model = self.load_model(selected_model)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="%s.pth" % self.model_name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||
return None
|
||||
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
|
||||
pretrained_net = fix_model_layers(crt_model, pretrained_net)
|
||||
crt_model.load_state_dict(pretrained_net)
|
||||
crt_model.eval()
|
||||
|
||||
return crt_model
|
||||
|
||||
|
||||
@ -154,7 +156,6 @@ def esrgan_upscale(model, img):
|
||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
||||
|
||||
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor,
|
||||
grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
||||
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
||||
output = images.combine_grid(newgrid)
|
||||
return output
|
||||
|
@ -67,6 +67,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||
image = res
|
||||
|
||||
if upscaling_resize != 1.0:
|
||||
def upscale(image, scaler_index, resize):
|
||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||
pixels = tuple(np.array(small).flatten().tolist())
|
||||
|
@ -36,8 +36,7 @@ def gfpgann():
|
||||
else:
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2,
|
||||
bg_upsampler=None)
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
model.gfpgan.to(shared.device)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
@ -49,8 +48,7 @@ def gfpgan_fix_faces(np_image):
|
||||
if model is None:
|
||||
return np_image
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False,
|
||||
only_center_face=False, paste_back=True)
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
@ -79,7 +77,6 @@ def setup_model(dirname):
|
||||
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
||||
|
||||
def my_load_file_from_url(**kwargs):
|
||||
print("Setting model_dir to " + model_path)
|
||||
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
@ -92,7 +89,6 @@ def setup_model(dirname):
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
||||
user_path = dirname
|
||||
print("Have gfpgan should be true?")
|
||||
have_gfpgan = True
|
||||
gfpgan_constructor = GFPGANer
|
||||
|
||||
@ -102,9 +98,7 @@ def setup_model(dirname):
|
||||
|
||||
def restore(self, np_image):
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True)
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
||||
return np_image
|
||||
|
@ -84,10 +84,8 @@ def combine_grid(grid):
|
||||
r = r.astype(np.uint8)
|
||||
return Image.fromarray(r, 'L')
|
||||
|
||||
mask_w = make_mask_image(
|
||||
np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
||||
mask_h = make_mask_image(
|
||||
np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
||||
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
||||
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
||||
|
||||
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
||||
for y, h, row in grid.tiles:
|
||||
@ -130,12 +128,10 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||
|
||||
def draw_texts(drawing, draw_x, draw_y, lines):
|
||||
for i, line in enumerate(lines):
|
||||
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt,
|
||||
fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
||||
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
||||
|
||||
if not line.is_active:
|
||||
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2,
|
||||
draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
||||
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
||||
|
||||
draw_y += line.size[1] + line_spacing
|
||||
|
||||
@ -206,10 +202,8 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
prompts_horiz = prompts[:boundary]
|
||||
prompts_vert = prompts[boundary:]
|
||||
|
||||
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in
|
||||
range(1 << len(prompts_horiz))]
|
||||
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in
|
||||
range(1 << len(prompts_vert))]
|
||||
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
||||
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
||||
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
||||
|
||||
@ -259,13 +253,11 @@ def resize_image(resize_mode, im, width, height):
|
||||
if ratio < src_ratio:
|
||||
fill_height = height // 2 - src_h // 2
|
||||
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
||||
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
|
||||
box=(0, fill_height + src_h))
|
||||
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
||||
elif ratio > src_ratio:
|
||||
fill_width = width // 2 - src_w // 2
|
||||
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
||||
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
|
||||
box=(fill_width + src_w, 0))
|
||||
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
||||
|
||||
return res
|
||||
|
||||
@ -300,8 +292,7 @@ def apply_filename_pattern(x, p, seed, prompt):
|
||||
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
|
||||
if len(words) == 0:
|
||||
words = ["empty"]
|
||||
x = x.replace("[prompt_words]",
|
||||
sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
||||
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
||||
|
||||
if p is not None:
|
||||
x = x.replace("[steps]", str(p.steps))
|
||||
@ -309,8 +300,7 @@ def apply_filename_pattern(x, p, seed, prompt):
|
||||
x = x.replace("[width]", str(p.width))
|
||||
x = x.replace("[height]", str(p.height))
|
||||
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
|
||||
x = x.replace("[sampler]",
|
||||
sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||
|
||||
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
||||
x = x.replace("[date]", datetime.date.today().isoformat())
|
||||
@ -336,8 +326,7 @@ def get_next_sequence_number(path, basename):
|
||||
prefix_length = len(basename)
|
||||
for p in os.listdir(path):
|
||||
if p.startswith(basename):
|
||||
l = os.path.splitext(p[prefix_length:])[0].split(
|
||||
'-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||
try:
|
||||
result = max(int(l[0]), result)
|
||||
except ValueError:
|
||||
@ -346,9 +335,7 @@ def get_next_sequence_number(path, basename):
|
||||
return result + 1
|
||||
|
||||
|
||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False,
|
||||
no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None,
|
||||
forced_filename=None, suffix=""):
|
||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
|
||||
if short_filename or prompt is None or seed is None:
|
||||
file_decoration = ""
|
||||
elif opts.save_to_dirs:
|
||||
|
@ -125,7 +125,6 @@ class LDSR:
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
print(f'Processing finished!')
|
||||
return a
|
||||
|
||||
|
||||
|
@ -25,8 +25,10 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||
|
||||
if ext_filter is None:
|
||||
ext_filter = []
|
||||
|
||||
try:
|
||||
places = []
|
||||
|
||||
if command_path is not None and command_path != model_path:
|
||||
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
|
||||
if os.path.exists(pretrained_path):
|
||||
@ -34,7 +36,9 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||
places.append(pretrained_path)
|
||||
elif os.path.exists(command_path):
|
||||
places.append(command_path)
|
||||
|
||||
places.append(model_path)
|
||||
|
||||
for place in places:
|
||||
if os.path.exists(place):
|
||||
for file in os.listdir(place):
|
||||
@ -47,14 +51,17 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||
continue
|
||||
if file not in output:
|
||||
output.append(full_path)
|
||||
|
||||
if model_url is not None and len(output) == 0:
|
||||
if download_name is not None:
|
||||
dl = load_file_from_url(model_url, model_path, True, download_name)
|
||||
output.append(dl)
|
||||
else:
|
||||
output.append(model_url)
|
||||
except:
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
@ -88,28 +88,24 @@ def get_realesrgan_models(scaler):
|
||||
models = [
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3"
|
||||
".pth",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
|
||||
act_type='prelu')
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General WDN 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
|
||||
act_type='prelu')
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN AnimeVideo",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4,
|
||||
act_type='prelu')
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+",
|
||||
|
@ -12,10 +12,10 @@ from modules import shared, modelloader
|
||||
from modules.paths import models_path
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
model_name = "sd-v1-4.ckpt"
|
||||
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
|
||||
user_dir = None
|
||||
user_dir: (str | None) = None
|
||||
|
||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||
checkpoints_list = {}
|
||||
@ -30,26 +30,8 @@ except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def modeltitle(path, h):
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
if abspath.startswith(model_dir):
|
||||
name = abspath.replace(model_dir, '')
|
||||
else:
|
||||
name = os.path.basename(path)
|
||||
|
||||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
return f'{name} [{h}]'
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
global model_name
|
||||
global model_url
|
||||
global user_dir
|
||||
global model_list
|
||||
user_dir = dirname
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
@ -62,21 +44,16 @@ def checkpoint_tiles():
|
||||
|
||||
|
||||
def list_models():
|
||||
global model_path
|
||||
global model_url
|
||||
global model_name
|
||||
global user_dir
|
||||
checkpoints_list.clear()
|
||||
model_list = modelloader.load_models(model_path=model_path,model_url=model_url,command_path= user_dir,
|
||||
ext_filter=[".ckpt"], download_name=model_name)
|
||||
print(f"Model list: {model_list}")
|
||||
model_dir = os.path.abspath(model_path)
|
||||
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
|
||||
|
||||
def modeltitle(path, h):
|
||||
def modeltitle(path, shorthash):
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
if abspath.startswith(model_dir):
|
||||
name = abspath.replace(model_dir, '')
|
||||
if user_dir is not None and abspath.startswith(user_dir):
|
||||
name = abspath.replace(user_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
name = abspath.replace(model_path, '')
|
||||
else:
|
||||
name = os.path.basename(path)
|
||||
|
||||
@ -85,19 +62,20 @@ def list_models():
|
||||
|
||||
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
|
||||
return f'{name} [{h}]', shortname
|
||||
return f'{name} [{shorthash}]', shortname
|
||||
|
||||
cmd_ckpt = shared.cmd_opts.ckpt
|
||||
if os.path.exists(cmd_ckpt):
|
||||
h = model_hash(cmd_ckpt)
|
||||
title, model_name = modeltitle(cmd_ckpt, h)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
|
||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||
for filename in model_list:
|
||||
h = model_hash(filename)
|
||||
title, model_name = modeltitle(filename, h)
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
|
||||
title, short_model_name = modeltitle(filename, h)
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
||||
|
||||
|
||||
def get_closet_checkpoint_match(searchString):
|
||||
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
|
||||
@ -105,9 +83,9 @@ def get_closet_checkpoint_match(searchString):
|
||||
return applicable[0]
|
||||
return None
|
||||
|
||||
|
||||
def model_hash(filename):
|
||||
try:
|
||||
print(f"Opening: {filename}")
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
m = hashlib.sha256()
|
||||
@ -128,7 +106,7 @@ def select_checkpoint():
|
||||
if len(checkpoints_list) == 0:
|
||||
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
||||
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
|
||||
print(f" - directory {os.path.abspath(shared.cmd_opts.stablediffusion_models_path)}", file=sys.stderr)
|
||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
||||
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
|
@ -21,8 +21,7 @@ model_path = os.path.join(script_path, 'models')
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
|
||||
# This should be deprecated, but we'll leave it for a few iterations
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints (Deprecated, use '--stablediffusion-models-path'", )
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
@ -41,7 +40,6 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
|
||||
parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
||||
@ -61,10 +59,6 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
if cmd_opts.ckpt_dir is not None:
|
||||
print("The 'ckpt-dir' arg is deprecated in favor of the 'stablediffusion-models-path' argument and will be "
|
||||
"removed in a future release. Please use the new option if you wish to use a custom checkpoint directory.")
|
||||
cmd_opts.__setattr__("stablediffusion-models-path", cmd_opts.ckpt_dir)
|
||||
device = get_optimal_device()
|
||||
|
||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||
|
2
webui.py
2
webui.py
@ -28,7 +28,7 @@ from modules.paths import script_path
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
modelloader.cleanup_models()
|
||||
modules.sd_models.setup_model(cmd_opts.stablediffusion_models_path)
|
||||
modules.sd_models.setup_model(cmd_opts.ckpt_dir)
|
||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||
|
Loading…
Reference in New Issue
Block a user