rework extras tab to use script system
This commit is contained in:
parent
68303c96e5
commit
b5230197a6
@ -104,11 +104,6 @@ function create_tab_index_args(tabId, args){
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
function get_extras_tab_index(){
|
|
||||||
const [,,...args] = [...arguments]
|
|
||||||
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
|
|
||||||
}
|
|
||||||
|
|
||||||
function get_img2img_tab_index() {
|
function get_img2img_tab_index() {
|
||||||
let res = args_to_array(arguments)
|
let res = args_to_array(arguments)
|
||||||
res.splice(-2)
|
res.splice(-2)
|
||||||
|
@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
|||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
||||||
from modules.api.models import *
|
from modules.api.models import *
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.extras import run_extras
|
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
@ -45,10 +44,8 @@ def validate_sampler_name(name):
|
|||||||
|
|
||||||
def setUpscalers(req: dict):
|
def setUpscalers(req: dict):
|
||||||
reqDict = vars(req)
|
reqDict = vars(req)
|
||||||
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
|
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
||||||
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
|
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
||||||
reqDict.pop('upscaler_1')
|
|
||||||
reqDict.pop('upscaler_2')
|
|
||||||
return reqDict
|
return reqDict
|
||||||
|
|
||||||
def decode_base64_to_image(encoding):
|
def decode_base64_to_image(encoding):
|
||||||
@ -244,7 +241,7 @@ class Api:
|
|||||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||||
|
|
||||||
@ -260,7 +257,7 @@ class Api:
|
|||||||
reqDict.pop('imageList')
|
reqDict.pop('imageList')
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||||
|
|
||||||
|
@ -1,219 +1,103 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from typing import Callable, List, OrderedDict, Tuple
|
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
|
||||||
from functools import partial
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from modules import shared, images, devices, ui_components
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
import modules.gfpgan_model
|
|
||||||
import modules.codeformer_model
|
|
||||||
|
|
||||||
|
|
||||||
class LruCache(OrderedDict):
|
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Key:
|
|
||||||
image_hash: int
|
|
||||||
info_hash: int
|
|
||||||
args_hash: int
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Value:
|
|
||||||
image: Image.Image
|
|
||||||
info: str
|
|
||||||
|
|
||||||
def __init__(self, max_size: int = 5, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._max_size = max_size
|
|
||||||
|
|
||||||
def get(self, key: LruCache.Key) -> LruCache.Value:
|
|
||||||
ret = super().get(key)
|
|
||||||
if ret is not None:
|
|
||||||
self.move_to_end(key) # Move to end of eviction list
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
|
|
||||||
self[key] = value
|
|
||||||
while len(self) > self._max_size:
|
|
||||||
self.popitem(last=False)
|
|
||||||
|
|
||||||
|
|
||||||
cached_images: LruCache = LruCache(max_size=5)
|
|
||||||
|
|
||||||
|
|
||||||
def run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'extras'
|
shared.state.job = 'extras'
|
||||||
|
|
||||||
imageArr = []
|
image_data = []
|
||||||
# Also keep track of original file names
|
image_names = []
|
||||||
imageNameArr = []
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
if extras_mode == 1:
|
if extras_mode == 1:
|
||||||
#convert file to pillow image
|
|
||||||
for img in image_folder:
|
for img in image_folder:
|
||||||
image = Image.open(img)
|
image = Image.open(img)
|
||||||
imageArr.append(image)
|
image_data.append(image)
|
||||||
imageNameArr.append(os.path.splitext(img.orig_name)[0])
|
image_names.append(os.path.splitext(img.orig_name)[0])
|
||||||
elif extras_mode == 2:
|
elif extras_mode == 2:
|
||||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||||
|
assert input_dir, 'input directory not selected'
|
||||||
|
|
||||||
if input_dir == '':
|
|
||||||
return outputs, "Please select an input directory.", ''
|
|
||||||
image_list = shared.listfiles(input_dir)
|
image_list = shared.listfiles(input_dir)
|
||||||
for img in image_list:
|
for filename in image_list:
|
||||||
try:
|
try:
|
||||||
image = Image.open(img)
|
image = Image.open(filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
imageArr.append(image)
|
image_data.append(image)
|
||||||
imageNameArr.append(img)
|
image_names.append(filename)
|
||||||
else:
|
else:
|
||||||
imageArr.append(image)
|
assert image, 'image not selected'
|
||||||
imageNameArr.append(None)
|
|
||||||
|
image_data.append(image)
|
||||||
|
image_names.append(None)
|
||||||
|
|
||||||
if extras_mode == 2 and output_dir != '':
|
if extras_mode == 2 and output_dir != '':
|
||||||
outpath = output_dir
|
outpath = output_dir
|
||||||
else:
|
else:
|
||||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||||
|
|
||||||
# Extra operation definitions
|
infotext = ''
|
||||||
|
|
||||||
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
for image, name in zip(image_data, image_names):
|
||||||
shared.state.job = 'extras-gfpgan'
|
shared.state.textinfo = name
|
||||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
|
||||||
res = Image.fromarray(restored_img)
|
|
||||||
|
|
||||||
if gfpgan_visibility < 1.0:
|
|
||||||
res = Image.blend(image, res, gfpgan_visibility)
|
|
||||||
|
|
||||||
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
|
|
||||||
return (res, info)
|
|
||||||
|
|
||||||
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
|
||||||
shared.state.job = 'extras-codeformer'
|
|
||||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
|
||||||
res = Image.fromarray(restored_img)
|
|
||||||
|
|
||||||
if codeformer_visibility < 1.0:
|
|
||||||
res = Image.blend(image, res, codeformer_visibility)
|
|
||||||
|
|
||||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
|
||||||
return (res, info)
|
|
||||||
|
|
||||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
|
||||||
shared.state.job = 'extras-upscale'
|
|
||||||
upscaler = shared.sd_upscalers[scaler_index]
|
|
||||||
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
|
||||||
if mode == 1 and crop:
|
|
||||||
cropped = Image.new("RGB", (resize_w, resize_h))
|
|
||||||
cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
|
|
||||||
res = cropped
|
|
||||||
return res
|
|
||||||
|
|
||||||
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
|
||||||
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
|
|
||||||
nonlocal upscaling_resize
|
|
||||||
if resize_mode == 1:
|
|
||||||
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
|
|
||||||
crop_info = " (crop)" if upscaling_crop else ""
|
|
||||||
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
|
|
||||||
return (image, info)
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class UpscaleParams:
|
|
||||||
upscaler_idx: int
|
|
||||||
blend_alpha: float
|
|
||||||
|
|
||||||
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
|
||||||
blended_result: Image.Image = None
|
|
||||||
image_hash: str = hash(np.array(image.getdata()).tobytes())
|
|
||||||
for upscaler in params:
|
|
||||||
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
|
|
||||||
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
|
||||||
cache_key = LruCache.Key(image_hash=image_hash,
|
|
||||||
info_hash=hash(info),
|
|
||||||
args_hash=hash(upscale_args))
|
|
||||||
cached_entry = cached_images.get(cache_key)
|
|
||||||
if cached_entry is None:
|
|
||||||
res = upscale(image, *upscale_args)
|
|
||||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
|
|
||||||
cached_images.put(cache_key, LruCache.Value(image=res, info=info))
|
|
||||||
else:
|
|
||||||
res, info = cached_entry.image, cached_entry.info
|
|
||||||
|
|
||||||
if blended_result is None:
|
|
||||||
blended_result = res
|
|
||||||
else:
|
|
||||||
blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
|
|
||||||
return (blended_result, info)
|
|
||||||
|
|
||||||
# Build a list of operations to run
|
|
||||||
facefix_ops: List[Callable] = []
|
|
||||||
facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
|
|
||||||
facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
|
|
||||||
|
|
||||||
upscale_ops: List[Callable] = []
|
|
||||||
upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
|
|
||||||
|
|
||||||
if upscaling_resize != 0:
|
|
||||||
step_params: List[UpscaleParams] = []
|
|
||||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
|
|
||||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
|
||||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
|
|
||||||
|
|
||||||
upscale_ops.append(partial(run_upscalers_blend, step_params))
|
|
||||||
|
|
||||||
extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
|
|
||||||
|
|
||||||
for image, image_name in zip(imageArr, imageNameArr):
|
|
||||||
if image is None:
|
|
||||||
return outputs, "Please select an input image.", ''
|
|
||||||
|
|
||||||
shared.state.textinfo = f'Processing image {image_name}'
|
|
||||||
|
|
||||||
existing_pnginfo = image.info or {}
|
existing_pnginfo = image.info or {}
|
||||||
|
|
||||||
image = image.convert("RGB")
|
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
||||||
info = ""
|
|
||||||
# Run each operation on each image
|
|
||||||
for op in extras_ops:
|
|
||||||
image, info = op(image, info)
|
|
||||||
|
|
||||||
if opts.use_original_name_batch and image_name is not None:
|
scripts.scripts_postproc.run(pp, args)
|
||||||
basename = os.path.splitext(os.path.basename(image_name))[0]
|
|
||||||
|
if opts.use_original_name_batch and name is not None:
|
||||||
|
basename = os.path.splitext(os.path.basename(name))[0]
|
||||||
else:
|
else:
|
||||||
basename = ''
|
basename = ''
|
||||||
|
|
||||||
if opts.enable_pnginfo: # append info before save
|
infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
|
||||||
image.info = existing_pnginfo
|
|
||||||
image.info["extras"] = info
|
if opts.enable_pnginfo:
|
||||||
|
pp.image.info = existing_pnginfo
|
||||||
|
pp.image.info["postprocessing"] = infotext
|
||||||
|
|
||||||
if save_output:
|
if save_output:
|
||||||
# Add upscaler name as a suffix.
|
images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
|
||||||
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
|
|
||||||
# Add second upscaler if applicable.
|
|
||||||
if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
|
|
||||||
suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
|
|
||||||
|
|
||||||
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
|
||||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
|
|
||||||
|
|
||||||
if extras_mode != 2 or show_extras_results:
|
if extras_mode != 2 or show_extras_results:
|
||||||
outputs.append(image)
|
outputs.append(pp.image)
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
return outputs, ui_components.plaintext_to_html(info), ''
|
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||||
|
|
||||||
|
|
||||||
def clear_cache():
|
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||||
cached_images.clear()
|
"""old handler for API"""
|
||||||
|
|
||||||
|
args = scripts.scripts_postproc.create_args_for_run({
|
||||||
|
"Upscale": {
|
||||||
|
"upscale_mode": resize_mode,
|
||||||
|
"upscale_by": upscaling_resize,
|
||||||
|
"upscale_to_width": upscaling_resize_w,
|
||||||
|
"upscale_to_height": upscaling_resize_h,
|
||||||
|
"upscale_crop": upscaling_crop,
|
||||||
|
"upscaler_1_name": extras_upscaler_1,
|
||||||
|
"upscaler_2_name": extras_upscaler_2,
|
||||||
|
"upscaler_2_visibility": extras_upscaler_2_visibility,
|
||||||
|
},
|
||||||
|
"GFPGAN": {
|
||||||
|
"gfpgan_visibility": gfpgan_visibility,
|
||||||
|
},
|
||||||
|
"CodeFormer": {
|
||||||
|
"codeformer_visibility": codeformer_visibility,
|
||||||
|
"codeformer_weight": codeformer_weight,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
|
||||||
|
@ -7,7 +7,7 @@ from collections import namedtuple
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules.processing import StableDiffusionProcessing
|
from modules.processing import StableDiffusionProcessing
|
||||||
from modules import shared, paths, script_callbacks, extensions, script_loading
|
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
|
||||||
|
|
||||||
AlwaysVisible = object()
|
AlwaysVisible = object()
|
||||||
|
|
||||||
@ -150,8 +150,10 @@ def basedir():
|
|||||||
return current_basedir
|
return current_basedir
|
||||||
|
|
||||||
|
|
||||||
scripts_data = []
|
|
||||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
||||||
|
|
||||||
|
scripts_data = []
|
||||||
|
postprocessing_scripts_data = []
|
||||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||||
|
|
||||||
|
|
||||||
@ -190,23 +192,31 @@ def list_files_with_name(filename):
|
|||||||
def load_scripts():
|
def load_scripts():
|
||||||
global current_basedir
|
global current_basedir
|
||||||
scripts_data.clear()
|
scripts_data.clear()
|
||||||
|
postprocessing_scripts_data.clear()
|
||||||
script_callbacks.clear_callbacks()
|
script_callbacks.clear_callbacks()
|
||||||
|
|
||||||
scripts_list = list_scripts("scripts", ".py")
|
scripts_list = list_scripts("scripts", ".py")
|
||||||
|
|
||||||
syspath = sys.path
|
syspath = sys.path
|
||||||
|
|
||||||
|
def register_scripts_from_module(module):
|
||||||
|
for key, script_class in module.__dict__.items():
|
||||||
|
if type(script_class) != type:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if issubclass(script_class, Script):
|
||||||
|
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||||
|
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||||
|
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||||
|
|
||||||
for scriptfile in sorted(scripts_list):
|
for scriptfile in sorted(scripts_list):
|
||||||
try:
|
try:
|
||||||
if scriptfile.basedir != paths.script_path:
|
if scriptfile.basedir != paths.script_path:
|
||||||
sys.path = [scriptfile.basedir] + sys.path
|
sys.path = [scriptfile.basedir] + sys.path
|
||||||
current_basedir = scriptfile.basedir
|
current_basedir = scriptfile.basedir
|
||||||
|
|
||||||
module = script_loading.load_module(scriptfile.path)
|
script_module = script_loading.load_module(scriptfile.path)
|
||||||
|
register_scripts_from_module(script_module)
|
||||||
for key, script_class in module.__dict__.items():
|
|
||||||
if type(script_class) == type and issubclass(script_class, Script):
|
|
||||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
||||||
@ -413,6 +423,7 @@ class ScriptRunner:
|
|||||||
|
|
||||||
scripts_txt2img = ScriptRunner()
|
scripts_txt2img = ScriptRunner()
|
||||||
scripts_img2img = ScriptRunner()
|
scripts_img2img = ScriptRunner()
|
||||||
|
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||||
scripts_current: ScriptRunner = None
|
scripts_current: ScriptRunner = None
|
||||||
|
|
||||||
|
|
||||||
@ -423,12 +434,13 @@ def reload_script_body_only():
|
|||||||
|
|
||||||
|
|
||||||
def reload_scripts():
|
def reload_scripts():
|
||||||
global scripts_txt2img, scripts_img2img
|
global scripts_txt2img, scripts_img2img, scripts_postproc
|
||||||
|
|
||||||
load_scripts()
|
load_scripts()
|
||||||
|
|
||||||
scripts_txt2img = ScriptRunner()
|
scripts_txt2img = ScriptRunner()
|
||||||
scripts_img2img = ScriptRunner()
|
scripts_img2img = ScriptRunner()
|
||||||
|
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
|
||||||
|
|
||||||
|
|
||||||
def IOComponent_init(self, *args, **kwargs):
|
def IOComponent_init(self, *args, **kwargs):
|
||||||
|
147
modules/scripts_postprocessing.py
Normal file
147
modules/scripts_postprocessing.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
import os
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import errors, shared
|
||||||
|
|
||||||
|
|
||||||
|
class PostprocessedImage:
|
||||||
|
def __init__(self, image):
|
||||||
|
self.image = image
|
||||||
|
self.info = {}
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptPostprocessing:
|
||||||
|
filename = None
|
||||||
|
controls = None
|
||||||
|
args_from = None
|
||||||
|
args_to = None
|
||||||
|
|
||||||
|
order = 1000
|
||||||
|
"""scripts will be ordred by this value in postprocessing UI"""
|
||||||
|
|
||||||
|
name = None
|
||||||
|
"""this function should return the title of the script."""
|
||||||
|
|
||||||
|
group = None
|
||||||
|
"""A gr.Group component that has all script's UI inside it"""
|
||||||
|
|
||||||
|
def ui(self):
|
||||||
|
"""
|
||||||
|
This function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||||
|
The return value should be a dictionary that maps parameter names to components used in processing.
|
||||||
|
Values of those components will be passed to process() function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
def process(self, pp: PostprocessedImage, **args):
|
||||||
|
"""
|
||||||
|
This function is called to postprocess the image.
|
||||||
|
args contains a dictionary with all values returned by components from ui()
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
def image_changed(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||||
|
try:
|
||||||
|
res = func(*args, **kwargs)
|
||||||
|
return res
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"calling {filename}/{funcname}")
|
||||||
|
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptPostprocessingRunner:
|
||||||
|
def __init__(self):
|
||||||
|
self.scripts = None
|
||||||
|
self.ui_created = False
|
||||||
|
|
||||||
|
def initialize_scripts(self, scripts_data):
|
||||||
|
self.scripts = []
|
||||||
|
|
||||||
|
for script_class, path, basedir, script_module in scripts_data:
|
||||||
|
script: ScriptPostprocessing = script_class()
|
||||||
|
script.filename = path
|
||||||
|
|
||||||
|
self.scripts.append(script)
|
||||||
|
|
||||||
|
def create_script_ui(self, script, inputs):
|
||||||
|
script.args_from = len(inputs)
|
||||||
|
script.args_to = len(inputs)
|
||||||
|
|
||||||
|
script.controls = wrap_call(script.ui, script.filename, "ui")
|
||||||
|
|
||||||
|
for control in script.controls.values():
|
||||||
|
control.custom_script_source = os.path.basename(script.filename)
|
||||||
|
|
||||||
|
inputs += list(script.controls.values())
|
||||||
|
script.args_to = len(inputs)
|
||||||
|
|
||||||
|
def scripts_in_preferred_order(self):
|
||||||
|
if self.scripts is None:
|
||||||
|
import modules.scripts
|
||||||
|
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
|
||||||
|
|
||||||
|
scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
|
||||||
|
|
||||||
|
def script_score(name):
|
||||||
|
name = name.lower()
|
||||||
|
for i, possible_match in enumerate(scripts_order):
|
||||||
|
if possible_match in name:
|
||||||
|
return i
|
||||||
|
|
||||||
|
return len(self.scripts)
|
||||||
|
|
||||||
|
script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
|
||||||
|
|
||||||
|
return sorted(self.scripts, key=lambda x: script_scores[x.name])
|
||||||
|
|
||||||
|
def setup_ui(self):
|
||||||
|
inputs = []
|
||||||
|
|
||||||
|
for script in self.scripts_in_preferred_order():
|
||||||
|
with gr.Box() as group:
|
||||||
|
self.create_script_ui(script, inputs)
|
||||||
|
|
||||||
|
script.group = group
|
||||||
|
|
||||||
|
self.ui_created = True
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def run(self, pp: PostprocessedImage, args):
|
||||||
|
for script in self.scripts_in_preferred_order():
|
||||||
|
shared.state.job = script.name
|
||||||
|
|
||||||
|
script_args = args[script.args_from:script.args_to]
|
||||||
|
|
||||||
|
process_args = {}
|
||||||
|
for (name, component), value in zip(script.controls.items(), script_args):
|
||||||
|
process_args[name] = value
|
||||||
|
|
||||||
|
script.process(pp, **process_args)
|
||||||
|
|
||||||
|
def create_args_for_run(self, scripts_args):
|
||||||
|
if not self.ui_created:
|
||||||
|
with gr.Blocks(analytics_enabled=False):
|
||||||
|
self.setup_ui()
|
||||||
|
|
||||||
|
scripts = self.scripts_in_preferred_order()
|
||||||
|
args = [None] * max([x.args_to for x in scripts])
|
||||||
|
|
||||||
|
for script in scripts:
|
||||||
|
script_args_dict = scripts_args.get(script.name, None)
|
||||||
|
if script_args_dict is not None:
|
||||||
|
|
||||||
|
for i, name in enumerate(script.controls):
|
||||||
|
args[script.args_from + i] = script_args_dict.get(name, None)
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
def image_changed(self):
|
||||||
|
for script in self.scripts_in_preferred_order():
|
||||||
|
script.image_changed()
|
@ -474,6 +474,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
|
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||||
|
'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
|
||||||
|
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section((None, "Hidden options"), {
|
options_templates.update(options_section((None, "Hidden options"), {
|
||||||
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
||||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||||
|
265
modules/ui.py
265
modules/ui.py
@ -20,7 +20,7 @@ import numpy as np
|
|||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components
|
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
@ -86,7 +86,6 @@ css_hide_progressbar = """
|
|||||||
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
|
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
|
||||||
reuse_symbol = '\u267b\ufe0f' # ♻️
|
reuse_symbol = '\u267b\ufe0f' # ♻️
|
||||||
paste_symbol = '\u2199\ufe0f' # ↙
|
paste_symbol = '\u2199\ufe0f' # ↙
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
apply_style_symbol = '\U0001f4cb' # 📋
|
apply_style_symbol = '\U0001f4cb' # 📋
|
||||||
@ -95,7 +94,7 @@ extra_networks_symbol = '\U0001F3B4' # 🎴
|
|||||||
|
|
||||||
|
|
||||||
def plaintext_to_html(text):
|
def plaintext_to_html(text):
|
||||||
return ui_components.plaintext_to_html(text)
|
return ui_common.plaintext_to_html(text)
|
||||||
|
|
||||||
|
|
||||||
def send_gradio_gallery_to_image(x):
|
def send_gradio_gallery_to_image(x):
|
||||||
@ -103,70 +102,6 @@ def send_gradio_gallery_to_image(x):
|
|||||||
return None
|
return None
|
||||||
return image_from_url_text(x[0])
|
return image_from_url_text(x[0])
|
||||||
|
|
||||||
def save_files(js_data, images, do_make_zip, index):
|
|
||||||
import csv
|
|
||||||
filenames = []
|
|
||||||
fullfns = []
|
|
||||||
|
|
||||||
#quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
|
|
||||||
class MyObject:
|
|
||||||
def __init__(self, d=None):
|
|
||||||
if d is not None:
|
|
||||||
for key, value in d.items():
|
|
||||||
setattr(self, key, value)
|
|
||||||
|
|
||||||
data = json.loads(js_data)
|
|
||||||
|
|
||||||
p = MyObject(data)
|
|
||||||
path = opts.outdir_save
|
|
||||||
save_to_dirs = opts.use_save_to_dirs_for_ui
|
|
||||||
extension: str = opts.samples_format
|
|
||||||
start_index = 0
|
|
||||||
|
|
||||||
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
|
||||||
|
|
||||||
images = [images[index]]
|
|
||||||
start_index = index
|
|
||||||
|
|
||||||
os.makedirs(opts.outdir_save, exist_ok=True)
|
|
||||||
|
|
||||||
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
|
||||||
at_start = file.tell() == 0
|
|
||||||
writer = csv.writer(file)
|
|
||||||
if at_start:
|
|
||||||
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
|
||||||
|
|
||||||
for image_index, filedata in enumerate(images, start_index):
|
|
||||||
image = image_from_url_text(filedata)
|
|
||||||
|
|
||||||
is_grid = image_index < p.index_of_first_image
|
|
||||||
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
|
||||||
|
|
||||||
fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
|
||||||
|
|
||||||
filename = os.path.relpath(fullfn, path)
|
|
||||||
filenames.append(filename)
|
|
||||||
fullfns.append(fullfn)
|
|
||||||
if txt_fullfn:
|
|
||||||
filenames.append(os.path.basename(txt_fullfn))
|
|
||||||
fullfns.append(txt_fullfn)
|
|
||||||
|
|
||||||
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
|
||||||
|
|
||||||
# Make Zip
|
|
||||||
if do_make_zip:
|
|
||||||
zip_filepath = os.path.join(path, "images.zip")
|
|
||||||
|
|
||||||
from zipfile import ZipFile
|
|
||||||
with ZipFile(zip_filepath, "w") as zip_file:
|
|
||||||
for i in range(len(fullfns)):
|
|
||||||
with open(fullfns[i], mode="rb") as f:
|
|
||||||
zip_file.writestr(filenames[i], f.read())
|
|
||||||
fullfns.insert(0, zip_filepath)
|
|
||||||
|
|
||||||
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
|
||||||
|
|
||||||
|
|
||||||
def visit(x, func, path=""):
|
def visit(x, func, path=""):
|
||||||
if hasattr(x, 'children'):
|
if hasattr(x, 'children'):
|
||||||
for c in x.children:
|
for c in x.children:
|
||||||
@ -444,19 +379,6 @@ def apply_setting(key, value):
|
|||||||
opts.save(shared.config_filename)
|
opts.save(shared.config_filename)
|
||||||
return getattr(opts, key)
|
return getattr(opts, key)
|
||||||
|
|
||||||
|
|
||||||
def update_generation_info(generation_info, html_info, img_index):
|
|
||||||
try:
|
|
||||||
generation_info = json.loads(generation_info)
|
|
||||||
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
|
|
||||||
return html_info, gr.update()
|
|
||||||
return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
# if the json parse or anything else fails, just return the old html_info
|
|
||||||
return html_info, gr.update()
|
|
||||||
|
|
||||||
|
|
||||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||||
def refresh():
|
def refresh():
|
||||||
refresh_method()
|
refresh_method()
|
||||||
@ -477,107 +399,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
|
|||||||
|
|
||||||
|
|
||||||
def create_output_panel(tabname, outdir):
|
def create_output_panel(tabname, outdir):
|
||||||
def open_folder(f):
|
return ui_common.create_output_panel(tabname, outdir)
|
||||||
if not os.path.exists(f):
|
|
||||||
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
|
||||||
return
|
|
||||||
elif not os.path.isdir(f):
|
|
||||||
print(f"""
|
|
||||||
WARNING
|
|
||||||
An open_folder request was made with an argument that is not a folder.
|
|
||||||
This could be an error or a malicious attempt to run code on your computer.
|
|
||||||
Requested path was: {f}
|
|
||||||
""", file=sys.stderr)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not shared.cmd_opts.hide_ui_dir_config:
|
|
||||||
path = os.path.normpath(f)
|
|
||||||
if platform.system() == "Windows":
|
|
||||||
os.startfile(path)
|
|
||||||
elif platform.system() == "Darwin":
|
|
||||||
sp.Popen(["open", path])
|
|
||||||
elif "microsoft-standard-WSL2" in platform.uname().release:
|
|
||||||
sp.Popen(["wsl-open", path])
|
|
||||||
else:
|
|
||||||
sp.Popen(["xdg-open", path])
|
|
||||||
|
|
||||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
|
||||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
|
||||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
|
|
||||||
|
|
||||||
generation_info = None
|
|
||||||
with gr.Column():
|
|
||||||
with gr.Row(elem_id=f"image_buttons_{tabname}"):
|
|
||||||
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
|
|
||||||
|
|
||||||
if tabname != "extras":
|
|
||||||
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
|
||||||
save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
|
|
||||||
|
|
||||||
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
|
|
||||||
|
|
||||||
open_folder_button.click(
|
|
||||||
fn=lambda: open_folder(opts.outdir_samples or outdir),
|
|
||||||
inputs=[],
|
|
||||||
outputs=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
if tabname != "extras":
|
|
||||||
with gr.Row():
|
|
||||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
|
||||||
|
|
||||||
with gr.Group():
|
|
||||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
|
||||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
|
||||||
|
|
||||||
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
|
||||||
if tabname == 'txt2img' or tabname == 'img2img':
|
|
||||||
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
|
|
||||||
generation_info_button.click(
|
|
||||||
fn=update_generation_info,
|
|
||||||
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
|
|
||||||
inputs=[generation_info, html_info, html_info],
|
|
||||||
outputs=[html_info, html_info],
|
|
||||||
)
|
|
||||||
|
|
||||||
save.click(
|
|
||||||
fn=wrap_gradio_call(save_files),
|
|
||||||
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
|
|
||||||
inputs=[
|
|
||||||
generation_info,
|
|
||||||
result_gallery,
|
|
||||||
html_info,
|
|
||||||
html_info,
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
download_files,
|
|
||||||
html_log,
|
|
||||||
],
|
|
||||||
show_progress=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
save_zip.click(
|
|
||||||
fn=wrap_gradio_call(save_files),
|
|
||||||
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
|
|
||||||
inputs=[
|
|
||||||
generation_info,
|
|
||||||
result_gallery,
|
|
||||||
html_info,
|
|
||||||
html_info,
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
download_files,
|
|
||||||
html_log,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
|
|
||||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
|
||||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
|
||||||
|
|
||||||
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
|
|
||||||
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
|
||||||
|
|
||||||
|
|
||||||
def create_sampler_and_steps_selection(choices, tabname):
|
def create_sampler_and_steps_selection(choices, tabname):
|
||||||
@ -1106,86 +928,7 @@ def create_ui():
|
|||||||
modules.scripts.scripts_current = None
|
modules.scripts.scripts_current = None
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||||
with gr.Row().style(equal_height=False):
|
ui_postprocessing.create_ui()
|
||||||
with gr.Column(variant='compact'):
|
|
||||||
with gr.Tabs(elem_id="mode_extras"):
|
|
||||||
with gr.TabItem('Single Image', elem_id="extras_single_tab"):
|
|
||||||
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
|
|
||||||
|
|
||||||
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
|
|
||||||
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
|
|
||||||
|
|
||||||
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
|
|
||||||
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
|
|
||||||
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
|
||||||
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
|
|
||||||
|
|
||||||
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
|
|
||||||
|
|
||||||
with gr.Tabs(elem_id="extras_resize_mode"):
|
|
||||||
with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
|
|
||||||
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
|
|
||||||
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
|
|
||||||
with gr.Group():
|
|
||||||
with gr.Row():
|
|
||||||
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
|
|
||||||
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
|
|
||||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
|
|
||||||
|
|
||||||
with gr.Group():
|
|
||||||
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
|
||||||
|
|
||||||
with gr.Group():
|
|
||||||
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
|
||||||
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
|
|
||||||
|
|
||||||
with gr.Group():
|
|
||||||
gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
|
|
||||||
|
|
||||||
with gr.Group():
|
|
||||||
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
|
|
||||||
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
|
|
||||||
|
|
||||||
with gr.Group():
|
|
||||||
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
|
|
||||||
|
|
||||||
result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
|
|
||||||
|
|
||||||
submit.click(
|
|
||||||
fn=wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
|
|
||||||
_js="get_extras_tab_index",
|
|
||||||
inputs=[
|
|
||||||
dummy_component,
|
|
||||||
dummy_component,
|
|
||||||
extras_image,
|
|
||||||
image_batch,
|
|
||||||
extras_batch_input_dir,
|
|
||||||
extras_batch_output_dir,
|
|
||||||
show_extras_results,
|
|
||||||
gfpgan_visibility,
|
|
||||||
codeformer_visibility,
|
|
||||||
codeformer_weight,
|
|
||||||
upscaling_resize,
|
|
||||||
upscaling_resize_w,
|
|
||||||
upscaling_resize_h,
|
|
||||||
upscaling_crop,
|
|
||||||
extras_upscaler_1,
|
|
||||||
extras_upscaler_2,
|
|
||||||
extras_upscaler_2_visibility,
|
|
||||||
upscale_before_face_fix,
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
result_images,
|
|
||||||
html_info_x,
|
|
||||||
html_info,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
parameters_copypaste.add_paste_fields("extras", extras_image, None)
|
|
||||||
|
|
||||||
extras_image.change(
|
|
||||||
fn=postprocessing.clear_cache,
|
|
||||||
inputs=[], outputs=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
|
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
|
202
modules/ui_common.py
Normal file
202
modules/ui_common.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
import json
|
||||||
|
import html
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import scipy as sp
|
||||||
|
|
||||||
|
from modules import call_queue, shared
|
||||||
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
|
import modules.images
|
||||||
|
|
||||||
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
|
|
||||||
|
|
||||||
|
def update_generation_info(generation_info, html_info, img_index):
|
||||||
|
try:
|
||||||
|
generation_info = json.loads(generation_info)
|
||||||
|
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
|
||||||
|
return html_info, gr.update()
|
||||||
|
return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# if the json parse or anything else fails, just return the old html_info
|
||||||
|
return html_info, gr.update()
|
||||||
|
|
||||||
|
|
||||||
|
def plaintext_to_html(text):
|
||||||
|
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def save_files(js_data, images, do_make_zip, index):
|
||||||
|
import csv
|
||||||
|
filenames = []
|
||||||
|
fullfns = []
|
||||||
|
|
||||||
|
#quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
|
||||||
|
class MyObject:
|
||||||
|
def __init__(self, d=None):
|
||||||
|
if d is not None:
|
||||||
|
for key, value in d.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
data = json.loads(js_data)
|
||||||
|
|
||||||
|
p = MyObject(data)
|
||||||
|
path = shared.opts.outdir_save
|
||||||
|
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
|
||||||
|
extension: str = shared.opts.samples_format
|
||||||
|
start_index = 0
|
||||||
|
|
||||||
|
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
||||||
|
|
||||||
|
images = [images[index]]
|
||||||
|
start_index = index
|
||||||
|
|
||||||
|
os.makedirs(shared.opts.outdir_save, exist_ok=True)
|
||||||
|
|
||||||
|
with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
||||||
|
at_start = file.tell() == 0
|
||||||
|
writer = csv.writer(file)
|
||||||
|
if at_start:
|
||||||
|
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
||||||
|
|
||||||
|
for image_index, filedata in enumerate(images, start_index):
|
||||||
|
image = image_from_url_text(filedata)
|
||||||
|
|
||||||
|
is_grid = image_index < p.index_of_first_image
|
||||||
|
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
||||||
|
|
||||||
|
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
||||||
|
|
||||||
|
filename = os.path.relpath(fullfn, path)
|
||||||
|
filenames.append(filename)
|
||||||
|
fullfns.append(fullfn)
|
||||||
|
if txt_fullfn:
|
||||||
|
filenames.append(os.path.basename(txt_fullfn))
|
||||||
|
fullfns.append(txt_fullfn)
|
||||||
|
|
||||||
|
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
||||||
|
|
||||||
|
# Make Zip
|
||||||
|
if do_make_zip:
|
||||||
|
zip_filepath = os.path.join(path, "images.zip")
|
||||||
|
|
||||||
|
from zipfile import ZipFile
|
||||||
|
with ZipFile(zip_filepath, "w") as zip_file:
|
||||||
|
for i in range(len(fullfns)):
|
||||||
|
with open(fullfns[i], mode="rb") as f:
|
||||||
|
zip_file.writestr(filenames[i], f.read())
|
||||||
|
fullfns.insert(0, zip_filepath)
|
||||||
|
|
||||||
|
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
||||||
|
|
||||||
|
|
||||||
|
def create_output_panel(tabname, outdir):
|
||||||
|
from modules import shared
|
||||||
|
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||||
|
|
||||||
|
def open_folder(f):
|
||||||
|
if not os.path.exists(f):
|
||||||
|
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
||||||
|
return
|
||||||
|
elif not os.path.isdir(f):
|
||||||
|
print(f"""
|
||||||
|
WARNING
|
||||||
|
An open_folder request was made with an argument that is not a folder.
|
||||||
|
This could be an error or a malicious attempt to run code on your computer.
|
||||||
|
Requested path was: {f}
|
||||||
|
""", file=sys.stderr)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not shared.cmd_opts.hide_ui_dir_config:
|
||||||
|
path = os.path.normpath(f)
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
os.startfile(path)
|
||||||
|
elif platform.system() == "Darwin":
|
||||||
|
sp.Popen(["open", path])
|
||||||
|
elif "microsoft-standard-WSL2" in platform.uname().release:
|
||||||
|
sp.Popen(["wsl-open", path])
|
||||||
|
else:
|
||||||
|
sp.Popen(["xdg-open", path])
|
||||||
|
|
||||||
|
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
||||||
|
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||||
|
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
|
||||||
|
|
||||||
|
generation_info = None
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row(elem_id=f"image_buttons_{tabname}"):
|
||||||
|
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
|
||||||
|
|
||||||
|
if tabname != "extras":
|
||||||
|
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
||||||
|
save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
|
||||||
|
|
||||||
|
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
|
||||||
|
|
||||||
|
open_folder_button.click(
|
||||||
|
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
if tabname != "extras":
|
||||||
|
with gr.Row():
|
||||||
|
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||||
|
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||||
|
|
||||||
|
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
||||||
|
if tabname == 'txt2img' or tabname == 'img2img':
|
||||||
|
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
|
||||||
|
generation_info_button.click(
|
||||||
|
fn=update_generation_info,
|
||||||
|
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
|
||||||
|
inputs=[generation_info, html_info, html_info],
|
||||||
|
outputs=[html_info, html_info],
|
||||||
|
)
|
||||||
|
|
||||||
|
save.click(
|
||||||
|
fn=call_queue.wrap_gradio_call(save_files),
|
||||||
|
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
|
||||||
|
inputs=[
|
||||||
|
generation_info,
|
||||||
|
result_gallery,
|
||||||
|
html_info,
|
||||||
|
html_info,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
download_files,
|
||||||
|
html_log,
|
||||||
|
],
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_zip.click(
|
||||||
|
fn=call_queue.wrap_gradio_call(save_files),
|
||||||
|
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
|
||||||
|
inputs=[
|
||||||
|
generation_info,
|
||||||
|
result_gallery,
|
||||||
|
html_info,
|
||||||
|
html_info,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
download_files,
|
||||||
|
html_log,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
|
||||||
|
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||||
|
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||||
|
|
||||||
|
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
|
||||||
|
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
@ -1,5 +1,3 @@
|
|||||||
import html
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
@ -50,7 +48,3 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
|
|||||||
def get_block_name(self):
|
def get_block_name(self):
|
||||||
return "colorpicker"
|
return "colorpicker"
|
||||||
|
|
||||||
|
|
||||||
def plaintext_to_html(text):
|
|
||||||
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
|
||||||
return text
|
|
||||||
|
57
modules/ui_postprocessing.py
Normal file
57
modules/ui_postprocessing.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import gradio as gr
|
||||||
|
from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
|
||||||
|
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui():
|
||||||
|
tab_index = gr.State(value=0)
|
||||||
|
|
||||||
|
with gr.Row().style(equal_height=False, variant='compact'):
|
||||||
|
with gr.Column(variant='compact'):
|
||||||
|
with gr.Tabs(elem_id="mode_extras"):
|
||||||
|
with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
|
||||||
|
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
|
||||||
|
|
||||||
|
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
|
||||||
|
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
|
||||||
|
|
||||||
|
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
|
||||||
|
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
|
||||||
|
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
||||||
|
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
|
||||||
|
|
||||||
|
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
|
||||||
|
|
||||||
|
script_inputs = scripts.scripts_postproc.setup_ui()
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
|
||||||
|
|
||||||
|
tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
|
||||||
|
tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
|
||||||
|
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
|
||||||
|
|
||||||
|
submit.click(
|
||||||
|
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
|
||||||
|
inputs=[
|
||||||
|
tab_index,
|
||||||
|
extras_image,
|
||||||
|
image_batch,
|
||||||
|
extras_batch_input_dir,
|
||||||
|
extras_batch_output_dir,
|
||||||
|
show_extras_results,
|
||||||
|
*script_inputs
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
result_images,
|
||||||
|
html_info_x,
|
||||||
|
html_info,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
parameters_copypaste.add_paste_fields("extras", extras_image, None)
|
||||||
|
|
||||||
|
extras_image.change(
|
||||||
|
fn=scripts.scripts_postproc.image_changed,
|
||||||
|
inputs=[], outputs=[]
|
||||||
|
)
|
36
scripts/postprocessing_codeformer.py
Normal file
36
scripts/postprocessing_codeformer.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from modules import scripts_postprocessing, codeformer_model
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules.ui_components import FormRow
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
|
||||||
|
name = "CodeFormer"
|
||||||
|
order = 3000
|
||||||
|
|
||||||
|
def ui(self):
|
||||||
|
with FormRow():
|
||||||
|
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
|
||||||
|
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"codeformer_visibility": codeformer_visibility,
|
||||||
|
"codeformer_weight": codeformer_weight,
|
||||||
|
}
|
||||||
|
|
||||||
|
def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
|
||||||
|
if codeformer_visibility == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
|
||||||
|
res = Image.fromarray(restored_img)
|
||||||
|
|
||||||
|
if codeformer_visibility < 1.0:
|
||||||
|
res = Image.blend(pp.image, res, codeformer_visibility)
|
||||||
|
|
||||||
|
pp.image = res
|
||||||
|
pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3)
|
||||||
|
pp.info["CodeFormer weight"] = round(codeformer_weight, 3)
|
33
scripts/postprocessing_gfpgan.py
Normal file
33
scripts/postprocessing_gfpgan.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from modules import scripts_postprocessing, gfpgan_model
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules.ui_components import FormRow
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
|
||||||
|
name = "GFPGAN"
|
||||||
|
order = 2000
|
||||||
|
|
||||||
|
def ui(self):
|
||||||
|
with FormRow():
|
||||||
|
gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"gfpgan_visibility": gfpgan_visibility,
|
||||||
|
}
|
||||||
|
|
||||||
|
def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
|
||||||
|
if gfpgan_visibility == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
|
||||||
|
res = Image.fromarray(restored_img)
|
||||||
|
|
||||||
|
if gfpgan_visibility < 1.0:
|
||||||
|
res = Image.blend(pp.image, res, gfpgan_visibility)
|
||||||
|
|
||||||
|
pp.image = res
|
||||||
|
pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3)
|
106
scripts/postprocessing_upscale.py
Normal file
106
scripts/postprocessing_upscale.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from modules import scripts_postprocessing, shared
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules.ui_components import FormRow
|
||||||
|
|
||||||
|
|
||||||
|
upscale_cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
||||||
|
name = "Upscale"
|
||||||
|
order = 1000
|
||||||
|
|
||||||
|
def ui(self):
|
||||||
|
selected_tab = gr.State(value=0)
|
||||||
|
|
||||||
|
with gr.Tabs(elem_id="extras_resize_mode"):
|
||||||
|
with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
|
||||||
|
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
|
||||||
|
|
||||||
|
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
|
||||||
|
with FormRow():
|
||||||
|
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
|
||||||
|
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
|
||||||
|
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
|
||||||
|
|
||||||
|
with FormRow():
|
||||||
|
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||||
|
|
||||||
|
with FormRow():
|
||||||
|
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||||
|
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
|
||||||
|
|
||||||
|
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
|
||||||
|
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"upscale_mode": selected_tab,
|
||||||
|
"upscale_by": upscaling_resize,
|
||||||
|
"upscale_to_width": upscaling_resize_w,
|
||||||
|
"upscale_to_height": upscaling_resize_h,
|
||||||
|
"upscale_crop": upscaling_crop,
|
||||||
|
"upscaler_1_name": extras_upscaler_1,
|
||||||
|
"upscaler_2_name": extras_upscaler_2,
|
||||||
|
"upscaler_2_visibility": extras_upscaler_2_visibility,
|
||||||
|
}
|
||||||
|
|
||||||
|
def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop):
|
||||||
|
if upscale_mode == 1:
|
||||||
|
upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)
|
||||||
|
info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}"
|
||||||
|
else:
|
||||||
|
info["Postprocess upscale by"] = upscale_by
|
||||||
|
|
||||||
|
cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
|
||||||
|
cached_image = upscale_cache.pop(cache_key, None)
|
||||||
|
|
||||||
|
if cached_image is not None:
|
||||||
|
image = cached_image
|
||||||
|
else:
|
||||||
|
image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path)
|
||||||
|
|
||||||
|
upscale_cache[cache_key] = image
|
||||||
|
if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache:
|
||||||
|
upscale_cache.pop(next(iter(upscale_cache), None), None)
|
||||||
|
|
||||||
|
if upscale_mode == 1 and upscale_crop:
|
||||||
|
cropped = Image.new("RGB", (upscale_to_width, upscale_to_height))
|
||||||
|
cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2))
|
||||||
|
image = cropped
|
||||||
|
info["Postprocess crop to"] = f"{image.width}x{image.height}"
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
|
||||||
|
if upscaler_1_name == "None":
|
||||||
|
upscaler_1_name = None
|
||||||
|
|
||||||
|
upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None)
|
||||||
|
assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}'
|
||||||
|
|
||||||
|
if not upscaler1:
|
||||||
|
return
|
||||||
|
|
||||||
|
if upscaler_2_name == "None":
|
||||||
|
upscaler_2_name = None
|
||||||
|
|
||||||
|
upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None)
|
||||||
|
assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'
|
||||||
|
|
||||||
|
upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
|
||||||
|
pp.info[f"Postprocess upscaler"] = upscaler1.name
|
||||||
|
|
||||||
|
if upscaler2 and upscaler_2_visibility > 0:
|
||||||
|
second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
|
||||||
|
upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)
|
||||||
|
|
||||||
|
pp.info[f"Postprocess upscaler 2"] = upscaler2.name
|
||||||
|
|
||||||
|
pp.image = upscaled_image
|
||||||
|
|
||||||
|
def image_changed(self):
|
||||||
|
upscale_cache.clear()
|
Loading…
Reference in New Issue
Block a user