add swinir v2 support
This commit is contained in:
parent
ece27fe989
commit
ed769977f0
@ -10,6 +10,7 @@ from tqdm import tqdm
|
|||||||
from modules import modelloader
|
from modules import modelloader
|
||||||
from modules.shared import cmd_opts, opts, device
|
from modules.shared import cmd_opts, opts, device
|
||||||
from modules.swinir_model_arch import SwinIR as net
|
from modules.swinir_model_arch import SwinIR as net
|
||||||
|
from modules.swinir_model_arch_v2 import Swin2SR as net2
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
precision_scope = (
|
precision_scope = (
|
||||||
@ -57,22 +58,42 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
filename = path
|
filename = path
|
||||||
if filename is None or not os.path.exists(filename):
|
if filename is None or not os.path.exists(filename):
|
||||||
return None
|
return None
|
||||||
model = net(
|
if filename.endswith(".v2.pth"):
|
||||||
|
model = net2(
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
img_size=64,
|
img_size=64,
|
||||||
window_size=8,
|
window_size=8,
|
||||||
img_range=1.0,
|
img_range=1.0,
|
||||||
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
depths=[6, 6, 6, 6, 6, 6],
|
||||||
embed_dim=240,
|
embed_dim=180,
|
||||||
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
num_heads=[6, 6, 6, 6, 6, 6],
|
||||||
mlp_ratio=2,
|
mlp_ratio=2,
|
||||||
upsampler="nearest+conv",
|
upsampler="nearest+conv",
|
||||||
resi_connection="3conv",
|
resi_connection="1conv",
|
||||||
)
|
)
|
||||||
|
params = None
|
||||||
|
else:
|
||||||
|
model = net(
|
||||||
|
upscale=scale,
|
||||||
|
in_chans=3,
|
||||||
|
img_size=64,
|
||||||
|
window_size=8,
|
||||||
|
img_range=1.0,
|
||||||
|
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
||||||
|
embed_dim=240,
|
||||||
|
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
||||||
|
mlp_ratio=2,
|
||||||
|
upsampler="nearest+conv",
|
||||||
|
resi_connection="3conv",
|
||||||
|
)
|
||||||
|
params = "params_ema"
|
||||||
|
|
||||||
pretrained_model = torch.load(filename)
|
pretrained_model = torch.load(filename)
|
||||||
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
if params is not None:
|
||||||
|
model.load_state_dict(pretrained_model[params], strict=True)
|
||||||
|
else:
|
||||||
|
model.load_state_dict(pretrained_model, strict=True)
|
||||||
if not cmd_opts.no_half:
|
if not cmd_opts.no_half:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
return model
|
return model
|
||||||
|
Loading…
Reference in New Issue
Block a user