2022-09-26 09:29:50 -05:00
import contextlib
import os
2022-09-26 09:29:22 -05:00
import numpy as np
import torch
2022-09-26 09:29:50 -05:00
from PIL import Image
from basicsr . utils . download_util import load_file_from_url
2022-10-01 14:04:20 -04:00
from tqdm import tqdm
2022-09-26 09:29:50 -05:00
2022-12-03 18:06:33 +03:00
from modules import modelloader , devices , script_callbacks , shared
2023-01-23 21:50:59 -05:00
from modules . shared import cmd_opts , opts , state
2022-12-03 18:06:33 +03:00
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
2022-09-29 17:46:23 -05:00
from modules . upscaler import Upscaler , UpscalerData
2022-09-26 09:29:22 -05:00
2022-12-03 18:06:33 +03:00
# Torch device used for all SwinIR work in this module; resolved once at
# import time through devices.get_device_for with the "swinir" key.
device_swinir = devices.get_device_for("swinir")
2022-09-29 17:46:23 -05:00
def _is_url(path):
    """Return True when *path* is an http(s) URL rather than a local file path.

    A prefix check is used instead of a substring test so that local paths that
    merely contain "http" (e.g. ``/srv/httpd/model.pth``) are not mistaken for URLs.
    """
    return str(path).startswith(("http://", "https://"))


class UpscalerSwinIR(Upscaler):
    """Upscaler that runs SwinIR (and Swin2SR ``*.v2.pth``) checkpoints."""

    def __init__(self, dirname):
        """Register one UpscalerData entry per model found by ``find_models``.

        Args:
            dirname: user directory searched for models (stored as ``user_path``).
        """
        self.name = "SwinIR"
        self.model_url = (
            "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0"
            "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR"
            "-L_x4_GAN.pth"
        )
        self.model_name = "SwinIR 4x"
        self.user_path = dirname
        super().__init__()
        scalers = []
        for model in self.find_models(ext_filter=[".pt", ".pth"]):
            # URL entries (presumably model_url, surfaced by find_models) get the
            # fixed display name; local files are named after their file name.
            name = self.model_name if _is_url(model) else modelloader.friendly_name(model)
            scalers.append(UpscalerData(name, model, self))
        self.scalers = scalers

    def do_upscale(self, img, model_file):
        """Upscale *img* with the model at *model_file*.

        Returns *img* unchanged when the model cannot be loaded.
        """
        model = self.load_model(model_file)
        if model is None:
            return img

        model = model.to(device_swinir, dtype=devices.dtype)
        img = upscale(img, model)
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Best effort: releasing cached CUDA memory must never fail the upscale.
            pass
        return img

    def load_model(self, path, scale=4):
        """Build the network for *path* and load its weights.

        Files ending in ``.v2.pth`` are built as Swin2SR and loaded from a raw
        state dict; everything else is built as SwinIR-L and loaded from the
        checkpoint's ``"params_ema"`` entry.

        Args:
            path: local checkpoint path, or an http(s) URL that is fetched via
                load_file_from_url into ``self.model_path`` as
                ``"<model_name with spaces replaced by _>.pth"``.
            scale: upscale factor passed to the network constructor.

        Returns:
            The model (on CPU), or None if the checkpoint file does not exist.
        """
        if _is_url(path):
            dl_name = f"{self.model_name.replace(' ', '_')}.pth"
            filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
        else:
            filename = path
        if filename is None or not os.path.exists(filename):
            return None
        if filename.endswith(".v2.pth"):
            model = net2(
                upscale=scale,
                in_chans=3,
                img_size=64,
                window_size=8,
                img_range=1.0,
                depths=[6, 6, 6, 6, 6, 6],
                embed_dim=180,
                num_heads=[6, 6, 6, 6, 6, 6],
                mlp_ratio=2,
                upsampler="nearest+conv",
                resi_connection="1conv",
            )
            params = None
        else:
            model = net(
                upscale=scale,
                in_chans=3,
                img_size=64,
                window_size=8,
                img_range=1.0,
                depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
                embed_dim=240,
                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                mlp_ratio=2,
                upsampler="nearest+conv",
                resi_connection="3conv",
            )
            params = "params_ema"

        # Load onto the CPU: the freshly built model lives on the CPU and is moved
        # to device_swinir by do_upscale, so this also works on CPU-only machines
        # for checkpoints that were saved from CUDA tensors.
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        pretrained_model = torch.load(filename, map_location="cpu")
        state_dict = pretrained_model[params] if params is not None else pretrained_model
        model.load_state_dict(state_dict, strict=True)
        return model
2022-09-26 09:29:22 -05:00
def upscale(
        img,
        model,
        tile=None,
        tile_overlap=None,
        window_size=8,
        scale=4,
):
    """Upscale an image with a SwinIR-family model, tile by tile.

    Args:
        img: input image. PIL images in a mode other than RGB are converted to
            RGB first, because the model expects exactly three channels.
        model: network expected to already be on device_swinir (do_upscale
            moves it there).
        tile: tile edge in pixels; None (or 0, which is not a usable tile size)
            falls back to opts.SWIN_tile.
        tile_overlap: overlap between neighbouring tiles in pixels; None falls
            back to opts.SWIN_tile_overlap, while an explicit 0 is honoured.
        window_size: the padded input is a multiple of this; tile must be too.
        scale: upscale factor, used to crop the model output.

    Returns:
        A PIL RGB image of size (width * scale, height * scale), assuming the
        model upsamples by exactly ``scale``.
    """
    tile = tile or opts.SWIN_tile
    if tile_overlap is None:
        tile_overlap = opts.SWIN_tile_overlap

    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")

    img = np.array(img)
    img = img[:, :, ::-1]  # RGB -> BGR channel order
    img = np.moveaxis(img, 2, 0) / 255  # HWC -> CHW, scaled to [0, 1]
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
    with torch.no_grad(), devices.autocast():
        _, _, h_old, w_old = img.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        # Mirror-pad the bottom and right edges up to the padded size. Keep
        # appending flipped copies while a single mirror is still too short
        # (inputs smaller than the pad). Otherwise the padded size would not be a
        # multiple of window_size and inference() would reject the tile.
        for dim, target in ((2, h_old + h_pad), (3, w_old + w_pad)):
            while img.size(dim) < target:
                img = torch.cat([img, torch.flip(img, [dim])], dim)
            img = img.narrow(dim, 0, target)
        output = inference(img, model, tile, tile_overlap, window_size, scale)
        output = output[..., : h_old * scale, : w_old * scale]  # drop the padding
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-BGR to HWC-RGB
    output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
    return Image.fromarray(output, "RGB")
def inference(img, model, tile, tile_overlap, window_size, scale):
    """Run *model* over *img* in overlapping tiles and average the results.

    Args:
        img: 4-D input tensor indexed as (b, c, h, w); h and w are expected to be
            multiples of window_size (upscale() pads them that way).
        model: network applied to each tile.
        tile: requested tile edge; clamped to the image size and required to be
            a multiple of window_size.
        tile_overlap: overlap between neighbouring tiles in pixels.
        window_size: model window size.
        scale: factor by which tile coordinates map into the output.

    Returns:
        Tensor of shape (b, c, h * scale, w * scale). Each output pixel is the
        mean of all tile outputs covering it. If generation is interrupted or
        skipped, regions no processed tile reached are 0 / 0 (NaN).
    """
    b, c, h, w = img.size()
    tile = min(tile, h, w)
    assert tile % window_size == 0, "tile size should be a multiple of window_size"
    sf = scale

    stride = tile - tile_overlap
    if stride <= 0:
        # An overlap >= the tile size (small images clamp `tile` down, and the UI
        # allows overlap > tile) would make range() raise on a zero step, or
        # silently yield no tiles on a negative step, leaving the output NaN.
        # Step one window instead; the final [size - tile] start still
        # guarantees full coverage.
        stride = window_size
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]

    # E accumulates tile outputs and W counts how many tiles covered each pixel.
    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
    W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)

    with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
        for h_idx in h_idx_list:
            if state.interrupted or state.skipped:
                break

            for w_idx in w_idx_list:
                if state.interrupted or state.skipped:
                    break

                in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)

                E[
                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
                ].add_(out_patch)
                W[
                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
                ].add_(out_patch_mask)
                pbar.update(1)
    output = E.div_(W)

    return output
2022-12-03 18:06:33 +03:00
def on_ui_settings():
    """Add the SwinIR tile-size and tile-overlap sliders to the Upscaling settings."""
    import gradio as gr

    slider_specs = (
        ("SWIN_tile", 192, "Tile size for all SwinIR.", {"minimum": 16, "maximum": 512, "step": 16}),
        ("SWIN_tile_overlap", 8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", {"minimum": 0, "maximum": 48, "step": 1}),
    )
    for key, default, label, slider_args in slider_specs:
        option = shared.OptionInfo(default, label, gr.Slider, slider_args, section=('upscaling', "Upscaling"))
        shared.opts.add_option(key, option)


# Hook the settings registration into the UI settings callback at import time.
script_callbacks.on_ui_settings(on_ui_settings)