gfpgan: just download the damn model
This commit is contained in:
parent
d6fd71f36f
commit
d4205e66fa
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
from modules import shared, devices
|
from modules import shared, devices
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
@ -11,14 +12,20 @@ import modules.face_restoration
|
|||||||
def gfpgan_model_path():
|
def gfpgan_model_path():
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
|
|
||||||
|
filemask = 'GFPGAN*.pth'
|
||||||
|
|
||||||
|
if cmd_opts.gfpgan_model is not None:
|
||||||
|
return cmd_opts.gfpgan_model
|
||||||
|
|
||||||
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
|
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
|
||||||
files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
|
|
||||||
found = [x for x in files if os.path.exists(x)]
|
|
||||||
|
|
||||||
if len(found) == 0:
|
filename = None
|
||||||
raise Exception("GFPGAN model not found in paths: " + ", ".join(files))
|
for place in places:
|
||||||
|
filename = next(iter(glob(os.path.join(place, filemask))), None)
|
||||||
|
if filename is not None:
|
||||||
|
break
|
||||||
|
|
||||||
return found[0]
|
return filename
|
||||||
|
|
||||||
|
|
||||||
loaded_gfpgan_model = None
|
loaded_gfpgan_model = None
|
||||||
@ -34,7 +41,7 @@ def gfpgan():
|
|||||||
if gfpgan_constructor is None:
|
if gfpgan_constructor is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
model = gfpgan_constructor(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||||
model.gfpgan.to(shared.device)
|
model.gfpgan.to(shared.device)
|
||||||
loaded_gfpgan_model = model
|
loaded_gfpgan_model = model
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ import sys
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from glob import glob
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
@ -22,7 +21,7 @@ parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs
|
|||||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
|
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
|
||||||
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
|
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
|
||||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=next(iter(glob('GFPGAN*.pth')), ''))
|
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
|
Loading…
Reference in New Issue
Block a user