Implementation for SD upscale.
This commit is contained in:
parent
9597b265ec
commit
4e0fdca2f4
11
README.md
11
README.md
@ -194,3 +194,14 @@ Using `()` in prompt decreases model's attention to enclosed words, and `[]` inc
|
|||||||
multiple modifiers:
|
multiple modifiers:
|
||||||
|
|
||||||
![](images/attention-3.jpg)
|
![](images/attention-3.jpg)
|
||||||
|
|
||||||
|
### SD upscale
|
||||||
|
Upscale image using RealESRGAN and then go through tiles of the result, improving them with img2img.
|
||||||
|
|
||||||
|
Original idea by: https://github.com/jquesnelle/txt2imghd. This is an independent implementation.
|
||||||
|
|
||||||
|
To use this feature, tick the "Stable Diffusion upscale" checkbox in the img2img interface. The original
|
||||||
|
image will be upscaled to twice the original width and height, while width and height sliders
|
||||||
|
will specify the size of individual tiles. At the moment this method does not support batch sizes greater than 1.
|
||||||
|
|
||||||
|
![](images/sd-upscale.jpg)
|
||||||
|
BIN
images/sd-upscale.jpg
Normal file
BIN
images/sd-upscale.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 712 KiB |
178
webui.py
178
webui.py
@ -85,11 +85,6 @@ try:
|
|||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
|
|
||||||
realesrgan_models = [
|
realesrgan_models = [
|
||||||
RealesrganModelInfo(
|
|
||||||
name="Real-ESRGAN 2x plus",
|
|
||||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
|
||||||
netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
|
||||||
),
|
|
||||||
RealesrganModelInfo(
|
RealesrganModelInfo(
|
||||||
name="Real-ESRGAN 4x plus",
|
name="Real-ESRGAN 4x plus",
|
||||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
@ -100,6 +95,11 @@ try:
|
|||||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||||
),
|
),
|
||||||
|
RealesrganModelInfo(
|
||||||
|
name="Real-ESRGAN 2x plus",
|
||||||
|
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||||
|
netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||||
|
),
|
||||||
]
|
]
|
||||||
have_realesrgan = True
|
have_realesrgan = True
|
||||||
except:
|
except:
|
||||||
@ -124,6 +124,7 @@ class Options:
|
|||||||
"verify_input": (True, "Check input, and produce warning if it's too long"),
|
"verify_input": (True, "Check input, and produce warning if it's too long"),
|
||||||
"enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
|
"enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
|
||||||
"prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
|
"prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
|
||||||
|
"sd_upscale_overlap": (64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", 0, 256, 16),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -289,6 +290,73 @@ def image_grid(imgs, batch_size, force_n_rows=None):
|
|||||||
return grid
|
return grid
|
||||||
|
|
||||||
|
|
||||||
|
# Result of split_grid(): `tiles` is a list of rows, each row being
# [y, tile_h, [[x, tile_w, tile], ...]]; the remaining fields record the
# geometry that combine_grid() needs to stitch the tiles back together.
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    """Cut ``image`` into a grid of overlapping tiles.

    Neighbouring tiles share ``overlap`` pixels. The last tile of each row and
    column is shifted back so that it ends exactly at the image edge, which
    means it may overlap its neighbour by more than ``overlap`` pixels.

    Args:
        image: object exposing ``width``, ``height`` and ``crop(box)``.
        tile_w: width of every tile, in pixels.
        tile_h: height of every tile, in pixels.
        overlap: number of pixels shared by adjacent tiles.

    Returns:
        A ``Grid`` whose ``tiles`` field is a list of rows of the form
        ``[y, tile_h, [[x, tile_w, tile], ...]]``.

    Raises:
        ValueError: if ``overlap`` is not smaller than both tile dimensions,
            because the grid could not advance from one tile to the next.
    """
    w = image.width
    h = image.height

    now = tile_w - overlap  # non-overlap width
    noh = tile_h - overlap  # non-overlap height

    if now <= 0 or noh <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than the tile size ({tile_w}x{tile_h})"
        )

    # Always produce at least one tile, even if the image is not larger than the overlap.
    cols = max(1, math.ceil((w - overlap) / now))
    rows = max(1, math.ceil((h - overlap) / noh))

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images = []

        y = row * noh

        # Pull the last row back to the bottom edge, but never above the image:
        # combine_grid() treats y == 0 as "first row", so a negative offset
        # would be blended at the wrong position.
        if y + tile_h >= h:
            y = max(h - tile_h, 0)

        for col in range(cols):
            x = col * now

            # Same clamping for the last column.
            if x + tile_w >= w:
                x = max(w - tile_w, 0)

            # When the image is smaller than a tile the box extends past the
            # image's right/bottom edge; filling that area is left to crop().
            tile = image.crop((x, y, x + tile_w, y + tile_h))

            row_images.append([x, tile_w, tile])

        grid.tiles.append([y, tile_h, row_images])

    return grid
|
||||||
|
|
||||||
|
|
||||||
|
def combine_grid(grid):
    """Stitch the tiles of ``grid`` back into a single RGB image.

    Tiles are first joined left-to-right into rows, then rows are joined
    top-to-bottom. Wherever a tile (or row) starts after the first one, its
    leading ``grid.overlap`` pixels are pasted through a linear gradient mask
    so they fade in over what is already there; the rest is pasted as-is.

    Args:
        grid: a ``Grid`` as produced by ``split_grid``; each tile is expected
            to be an image of ``grid.tile_w`` x ``grid.tile_h`` pixels.

    Returns:
        A new ``Image`` of size ``grid.image_w`` x ``grid.image_h``.
    """
    overlap = grid.overlap

    def make_mask_image(r):
        # Scale the 0..overlap-1 ramp to the 0..255 range of an 8-bit mask.
        r = r * 255 / overlap
        r = r.astype(np.uint8)
        return Image.fromarray(r, 'L')

    if overlap > 0:
        # np.float64 is the dtype the removed ``np.float`` alias used to stand for.
        ramp = np.arange(overlap, dtype=np.float64)
        # Horizontal ramp (one tile high) for joining tiles within a row.
        mask_w = make_mask_image(ramp.reshape((1, overlap)).repeat(grid.tile_h, axis=0))
        # Vertical ramp (full image width) for joining rows.
        mask_h = make_mask_image(ramp.reshape((overlap, 1)).repeat(grid.image_w, axis=1))
    else:
        # No overlap means nothing to blend; tiles are pasted side by side.
        mask_w = mask_h = None

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            # The first tile of a row (or any tile when there is no overlap)
            # has nothing to blend with.
            if x == 0 or overlap <= 0:
                combined_row.paste(tile, (x, 0))
                continue

            combined_row.paste(tile.crop((0, 0, overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((overlap, 0, w, h)), (x + overlap, 0))

        if y == 0 or overlap <= 0:
            combined_image.paste(combined_row, (0, y))
            continue

        combined_image.paste(combined_row.crop((0, 0, combined_row.width, overlap)), (0, y), mask=mask_h)
        combined_image.paste(combined_row.crop((0, overlap, combined_row.width, h)), (0, y + overlap))

    return combined_image
|
||||||
|
|
||||||
|
|
||||||
def draw_prompt_matrix(im, width, height, all_prompts):
|
def draw_prompt_matrix(im, width, height, all_prompts):
|
||||||
def wrap(text, d, font, line_length):
|
def wrap(text, d, font, line_length):
|
||||||
lines = ['']
|
lines = ['']
|
||||||
@ -491,6 +559,7 @@ class StableDiffuionModelHijack:
|
|||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
def __init__(self, wrapped, embeddings):
|
def __init__(self, wrapped, embeddings):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -740,8 +809,6 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
|
|||||||
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
torch_gc()
|
torch_gc()
|
||||||
return output_images, seed, infotext()
|
return output_images, seed, infotext()
|
||||||
|
|
||||||
@ -847,7 +914,7 @@ txt2img_interface = gr.Interface(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
||||||
outpath = opts.outdir or "outputs/img2img-samples"
|
outpath = opts.outdir or "outputs/img2img-samples"
|
||||||
|
|
||||||
sampler = samplers_for_img2img[sampler_index].constructor(model)
|
sampler = samplers_for_img2img[sampler_index].constructor(model)
|
||||||
@ -894,7 +961,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
|
|||||||
func_sample=sample,
|
func_sample=sample,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
sampler_index=0,
|
sampler_index=sampler_index,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
n_iter=1,
|
n_iter=1,
|
||||||
steps=ddim_steps,
|
steps=ddim_steps,
|
||||||
@ -923,6 +990,59 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
|
|||||||
output_images = history
|
output_images = history
|
||||||
seed = initial_seed
|
seed = initial_seed
|
||||||
|
|
||||||
|
elif sd_upscale:
|
||||||
|
initial_seed = None
|
||||||
|
initial_info = None
|
||||||
|
|
||||||
|
img = upscale_with_realesrgan(init_img, RealESRGAN_upscaling=2, RealESRGAN_model_index=0)
|
||||||
|
|
||||||
|
torch_gc()
|
||||||
|
|
||||||
|
grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)
|
||||||
|
|
||||||
|
|
||||||
|
print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.")
|
||||||
|
|
||||||
|
for y, h, row in grid.tiles:
|
||||||
|
for tiledata in row:
|
||||||
|
init_img = tiledata[2]
|
||||||
|
|
||||||
|
output_images, seed, info = process_images(
|
||||||
|
outpath=outpath,
|
||||||
|
func_init=init,
|
||||||
|
func_sample=sample,
|
||||||
|
prompt=prompt,
|
||||||
|
seed=seed,
|
||||||
|
sampler_index=sampler_index,
|
||||||
|
batch_size=1, # since process_images can't work with multiple different images we have to do this for now
|
||||||
|
n_iter=1,
|
||||||
|
steps=ddim_steps,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
prompt_matrix=prompt_matrix,
|
||||||
|
use_GFPGAN=use_GFPGAN,
|
||||||
|
do_not_save_grid=True,
|
||||||
|
extra_generation_params={"Denoising Strength": denoising_strength},
|
||||||
|
)
|
||||||
|
|
||||||
|
if initial_seed is None:
|
||||||
|
initial_seed = seed
|
||||||
|
initial_info = info
|
||||||
|
|
||||||
|
seed += 1
|
||||||
|
|
||||||
|
tiledata[2] = output_images[0]
|
||||||
|
|
||||||
|
combined_image = combine_grid(grid)
|
||||||
|
|
||||||
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)
|
||||||
|
|
||||||
|
output_images = [combined_image]
|
||||||
|
seed = initial_seed
|
||||||
|
info = initial_info
|
||||||
|
|
||||||
else:
|
else:
|
||||||
output_images, seed, info = process_images(
|
output_images, seed, info = process_images(
|
||||||
outpath=outpath,
|
outpath=outpath,
|
||||||
@ -930,7 +1050,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
|
|||||||
func_sample=sample,
|
func_sample=sample,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
sampler_index=0,
|
sampler_index=sampler_index,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=ddim_steps,
|
steps=ddim_steps,
|
||||||
@ -960,6 +1080,7 @@ img2img_interface = gr.Interface(
|
|||||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||||
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
||||||
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
|
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
|
||||||
|
gr.Checkbox(label='Stable Diffusion upscale', value=False),
|
||||||
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
||||||
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
||||||
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
|
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
|
||||||
@ -978,7 +1099,26 @@ img2img_interface = gr.Interface(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
    """Upscale ``image`` with one of the entries of ``realesrgan_models``.

    Args:
        image: the picture to enlarge; it is converted with ``np.array``
            before being handed to the upsampler.
        RealESRGAN_upscaling: factor passed to the upsampler as ``outscale``.
        RealESRGAN_model_index: index into ``realesrgan_models`` selecting
            which model description to use.

    Returns:
        The upscaled picture as a new ``Image``.
    """
    model_info = realesrgan_models[RealESRGAN_model_index]

    # Build the network first, then wrap it in an upsampler.
    network = model_info.model()
    enhancer = RealESRGANer(
        scale=model_info.netscale,
        model_path=model_info.location,
        model=network,
        half=True,
    )

    # Only the first element of enhance()'s result is used: the output pixels.
    result = enhancer.enhance(np.array(image), outscale=RealESRGAN_upscaling)
    return Image.fromarray(result[0])
|
||||||
|
|
||||||
|
|
||||||
def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
|
def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
|
||||||
|
torch_gc()
|
||||||
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
|
||||||
outpath = opts.outdir or "outputs/extras-samples"
|
outpath = opts.outdir or "outputs/extras-samples"
|
||||||
@ -993,19 +1133,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
|
|||||||
image = res
|
image = res
|
||||||
|
|
||||||
if have_realesrgan and RealESRGAN_upscaling != 1.0:
|
if have_realesrgan and RealESRGAN_upscaling != 1.0:
|
||||||
info = realesrgan_models[RealESRGAN_model_index]
|
image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
|
||||||
|
|
||||||
model = info.model()
|
|
||||||
upsampler = RealESRGANer(
|
|
||||||
scale=info.netscale,
|
|
||||||
model_path=info.location,
|
|
||||||
model=model,
|
|
||||||
half=True
|
|
||||||
)
|
|
||||||
|
|
||||||
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
|
|
||||||
|
|
||||||
image = Image.fromarray(upsampled)
|
|
||||||
|
|
||||||
os.makedirs(outpath, exist_ok=True)
|
os.makedirs(outpath, exist_ok=True)
|
||||||
base_count = len(os.listdir(outpath))
|
base_count = len(os.listdir(outpath))
|
||||||
@ -1058,7 +1186,9 @@ def create_setting_component(key):
|
|||||||
if t == str:
|
if t == str:
|
||||||
item = gr.Textbox(label=label, value=fun, lines=1)
|
item = gr.Textbox(label=label, value=fun, lines=1)
|
||||||
elif t == int:
|
elif t == int:
|
||||||
if len(labelinfo) == 4:
|
if len(labelinfo) == 5:
|
||||||
|
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun)
|
||||||
|
elif len(labelinfo) == 4:
|
||||||
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
|
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
|
||||||
else:
|
else:
|
||||||
item = gr.Number(label=label, value=fun)
|
item = gr.Number(label=label, value=fun)
|
||||||
|
Loading…
Reference in New Issue
Block a user