Merge pull request #11696 from WuSiYu/feat_SWIN_torch_compile
feat: add option SWIN_torch_compile to accelerate SwinIR upscale
This commit is contained in:
commit
bcb6ad5fab
@ -1,4 +1,5 @@
|
|||||||
import sys
|
import sys
|
||||||
|
import platform
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -18,6 +19,8 @@ device_swinir = devices.get_device_for('swinir')
|
|||||||
|
|
||||||
class UpscalerSwinIR(Upscaler):
|
class UpscalerSwinIR(Upscaler):
|
||||||
def __init__(self, dirname):
|
def __init__(self, dirname):
|
||||||
|
self._cached_model = None # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
|
||||||
|
self._cached_model_config = None # to clear '_cached_model' when changing model (v1/v2) or settings
|
||||||
self.name = "SwinIR"
|
self.name = "SwinIR"
|
||||||
self.model_url = SWINIR_MODEL_URL
|
self.model_url = SWINIR_MODEL_URL
|
||||||
self.model_name = "SwinIR 4x"
|
self.model_name = "SwinIR 4x"
|
||||||
@ -35,12 +38,24 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
self.scalers = scalers
|
self.scalers = scalers
|
||||||
|
|
||||||
def do_upscale(self, img, model_file):
|
def do_upscale(self, img, model_file):
|
||||||
try:
|
use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
|
||||||
model = self.load_model(model_file)
|
and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
|
||||||
except Exception as e:
|
current_config = (model_file, opts.SWIN_tile)
|
||||||
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
|
||||||
return img
|
if use_compile and self._cached_model_config == current_config:
|
||||||
model = model.to(device_swinir, dtype=devices.dtype)
|
model = self._cached_model
|
||||||
|
else:
|
||||||
|
self._cached_model = None
|
||||||
|
try:
|
||||||
|
model = self.load_model(model_file)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
||||||
|
return img
|
||||||
|
model = model.to(device_swinir, dtype=devices.dtype)
|
||||||
|
if use_compile:
|
||||||
|
model = torch.compile(model)
|
||||||
|
self._cached_model = model
|
||||||
|
self._cached_model_config = current_config
|
||||||
img = upscale(img, model)
|
img = upscale(img, model)
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
return img
|
return img
|
||||||
@ -170,6 +185,8 @@ def on_ui_settings():
|
|||||||
|
|
||||||
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
|
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
|
||||||
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
|
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
|
||||||
|
if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
|
||||||
|
shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
|
||||||
|
|
||||||
|
|
||||||
script_callbacks.on_ui_settings(on_ui_settings)
|
script_callbacks.on_ui_settings(on_ui_settings)
|
||||||
|
Loading…
Reference in New Issue
Block a user