Implementation for SD upscale.
This commit is contained in:
parent
9597b265ec
commit
4e0fdca2f4
11
README.md
11
README.md
@ -194,3 +194,14 @@ Using `()` in prompt decreases model's attention to enclosed words, and `[]` inc
|
||||
multiple modifiers:
|
||||
|
||||
![](images/attention-3.jpg)
|
||||
|
||||
### SD upscale
|
||||
Upscale image using RealESRGAN and then go through tiles of the result, improving them with img2img.
|
||||
|
||||
Original idea by: https://github.com/jquesnelle/txt2imghd. This is an independent implementation.
|
||||
|
||||
To use this feature, tick a checkbox in the img2img interface. Original
|
||||
image will be upscaled to twice the original width and height, while width and height sliders
|
||||
will specify the size of individual tiles. At the moment this method does not support batch size.
|
||||
|
||||
![](images/sd-upscale.jpg)
|
||||
|
BIN
images/sd-upscale.jpg
Normal file
BIN
images/sd-upscale.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 712 KiB |
178
webui.py
178
webui.py
@ -85,11 +85,6 @@ try:
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
|
||||
realesrgan_models = [
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN 2x plus",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN 4x plus",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
@ -100,6 +95,11 @@ try:
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN 2x plus",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||
),
|
||||
]
|
||||
have_realesrgan = True
|
||||
except:
|
||||
@ -124,6 +124,7 @@ class Options:
|
||||
"verify_input": (True, "Check input, and produce warning if it's too long"),
|
||||
"enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
|
||||
"prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
|
||||
"sd_upscale_overlap": (64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", 0, 256, 16),
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
@ -289,6 +290,73 @@ def image_grid(imgs, batch_size, force_n_rows=None):
|
||||
return grid
|
||||
|
||||
|
||||
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
|
||||
|
||||
|
||||
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
||||
w = image.width
|
||||
h = image.height
|
||||
|
||||
now = tile_w - overlap # non-overlap width
|
||||
noh = tile_h - overlap
|
||||
|
||||
cols = math.ceil((w - overlap) / now)
|
||||
rows = math.ceil((h - overlap) / noh)
|
||||
|
||||
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
||||
for row in range(rows):
|
||||
row_images = []
|
||||
|
||||
y = row * noh
|
||||
|
||||
if y + tile_h >= h:
|
||||
y = h - tile_h
|
||||
|
||||
for col in range(cols):
|
||||
x = col * now
|
||||
|
||||
if x+tile_w >= w:
|
||||
x = w - tile_w
|
||||
|
||||
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
||||
|
||||
row_images.append([x, tile_w, tile])
|
||||
|
||||
grid.tiles.append([y, tile_h, row_images])
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def combine_grid(grid):
|
||||
def make_mask_image(r):
|
||||
r = r * 255 / grid.overlap
|
||||
r = r.astype(np.uint8)
|
||||
return Image.fromarray(r, 'L')
|
||||
|
||||
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
||||
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
||||
|
||||
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
||||
for y, h, row in grid.tiles:
|
||||
combined_row = Image.new("RGB", (grid.image_w, h))
|
||||
for x, w, tile in row:
|
||||
if x == 0:
|
||||
combined_row.paste(tile, (0, 0))
|
||||
continue
|
||||
|
||||
combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
|
||||
combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
|
||||
|
||||
if y == 0:
|
||||
combined_image.paste(combined_row, (0, 0))
|
||||
continue
|
||||
|
||||
combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
|
||||
combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
|
||||
|
||||
return combined_image
|
||||
|
||||
|
||||
def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
def wrap(text, d, font, line_length):
|
||||
lines = ['']
|
||||
@ -491,6 +559,7 @@ class StableDiffuionModelHijack:
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
def __init__(self, wrapped, embeddings):
|
||||
super().__init__()
|
||||
@ -740,8 +809,6 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
|
||||
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
||||
grid_count += 1
|
||||
|
||||
|
||||
|
||||
torch_gc()
|
||||
return output_images, seed, infotext()
|
||||
|
||||
@ -847,7 +914,7 @@ txt2img_interface = gr.Interface(
|
||||
)
|
||||
|
||||
|
||||
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
||||
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
||||
outpath = opts.outdir or "outputs/img2img-samples"
|
||||
|
||||
sampler = samplers_for_img2img[sampler_index].constructor(model)
|
||||
@ -894,7 +961,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_index=0,
|
||||
sampler_index=sampler_index,
|
||||
batch_size=1,
|
||||
n_iter=1,
|
||||
steps=ddim_steps,
|
||||
@ -923,6 +990,59 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
|
||||
output_images = history
|
||||
seed = initial_seed
|
||||
|
||||
elif sd_upscale:
|
||||
initial_seed = None
|
||||
initial_info = None
|
||||
|
||||
img = upscale_with_realesrgan(init_img, RealESRGAN_upscaling=2, RealESRGAN_model_index=0)
|
||||
|
||||
torch_gc()
|
||||
|
||||
grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)
|
||||
|
||||
|
||||
print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.")
|
||||
|
||||
for y, h, row in grid.tiles:
|
||||
for tiledata in row:
|
||||
init_img = tiledata[2]
|
||||
|
||||
output_images, seed, info = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_index=sampler_index,
|
||||
batch_size=1, # since process_images can't work with multiple different images we have to do this for now
|
||||
n_iter=1,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=prompt_matrix,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
do_not_save_grid=True,
|
||||
extra_generation_params={"Denoising Strength": denoising_strength},
|
||||
)
|
||||
|
||||
if initial_seed is None:
|
||||
initial_seed = seed
|
||||
initial_info = info
|
||||
|
||||
seed += 1
|
||||
|
||||
tiledata[2] = output_images[0]
|
||||
|
||||
combined_image = combine_grid(grid)
|
||||
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)
|
||||
|
||||
output_images = [combined_image]
|
||||
seed = initial_seed
|
||||
info = initial_info
|
||||
|
||||
else:
|
||||
output_images, seed, info = process_images(
|
||||
outpath=outpath,
|
||||
@ -930,7 +1050,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_index=0,
|
||||
sampler_index=sampler_index,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
@ -960,6 +1080,7 @@ img2img_interface = gr.Interface(
|
||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
||||
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
|
||||
gr.Checkbox(label='Stable Diffusion upscale', value=False),
|
||||
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
||||
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
||||
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
|
||||
@ -978,7 +1099,26 @@ img2img_interface = gr.Interface(
|
||||
)
|
||||
|
||||
|
||||
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
|
||||
info = realesrgan_models[RealESRGAN_model_index]
|
||||
|
||||
model = info.model()
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.netscale,
|
||||
model_path=info.location,
|
||||
model=model,
|
||||
half=True
|
||||
)
|
||||
|
||||
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
|
||||
|
||||
image = Image.fromarray(upsampled)
|
||||
return image
|
||||
|
||||
|
||||
def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
|
||||
torch_gc()
|
||||
|
||||
image = image.convert("RGB")
|
||||
|
||||
outpath = opts.outdir or "outputs/extras-samples"
|
||||
@ -993,19 +1133,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
|
||||
image = res
|
||||
|
||||
if have_realesrgan and RealESRGAN_upscaling != 1.0:
|
||||
info = realesrgan_models[RealESRGAN_model_index]
|
||||
|
||||
model = info.model()
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.netscale,
|
||||
model_path=info.location,
|
||||
model=model,
|
||||
half=True
|
||||
)
|
||||
|
||||
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
|
||||
|
||||
image = Image.fromarray(upsampled)
|
||||
image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
|
||||
|
||||
os.makedirs(outpath, exist_ok=True)
|
||||
base_count = len(os.listdir(outpath))
|
||||
@ -1058,7 +1186,9 @@ def create_setting_component(key):
|
||||
if t == str:
|
||||
item = gr.Textbox(label=label, value=fun, lines=1)
|
||||
elif t == int:
|
||||
if len(labelinfo) == 4:
|
||||
if len(labelinfo) == 5:
|
||||
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun)
|
||||
elif len(labelinfo) == 4:
|
||||
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
|
||||
else:
|
||||
item = gr.Number(label=label, value=fun)
|
||||
|
Loading…
Reference in New Issue
Block a user