Merge pull request #135 from rewbs/img2img2-color-correction
Add color correction to img2img loopback to avoid a progressive skew to magenta. Based on codedealer's PR to hlky's repo here: https://github.com/sd-webui/stable-diffusion-webui/pull/698/files.
This commit is contained in:
commit
9ddaf8269e
@ -1,4 +1,6 @@
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps, ImageChops
|
||||
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
@ -59,8 +61,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
|
||||
|
||||
state.job_count = n_iter
|
||||
|
||||
do_color_correction = False
|
||||
try:
|
||||
from skimage import exposure
|
||||
do_color_correction = True
|
||||
except:
|
||||
print("Install scikit-image to perform color correction on loopback")
|
||||
|
||||
|
||||
for i in range(n_iter):
|
||||
|
||||
if do_color_correction and i == 0:
|
||||
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
|
||||
|
||||
p.n_iter = 1
|
||||
p.batch_size = 1
|
||||
p.do_not_save_grid = True
|
||||
@ -72,7 +85,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
|
||||
initial_seed = processed.seed
|
||||
initial_info = processed.info
|
||||
|
||||
p.init_images = [processed.images[0]]
|
||||
init_img = processed.images[0]
|
||||
|
||||
if do_color_correction and correction_target is not None:
|
||||
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
|
||||
cv2.cvtColor(
|
||||
np.asarray(init_img),
|
||||
cv2.COLOR_RGB2LAB
|
||||
),
|
||||
correction_target,
|
||||
channel_axis=2
|
||||
), cv2.COLOR_LAB2RGB).astype("uint8"))
|
||||
|
||||
p.init_images = [init_img]
|
||||
p.seed = processed.seed + 1
|
||||
p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
|
||||
history.append(processed.images[0])
|
||||
|
@ -10,5 +10,6 @@ omegaconf
|
||||
pytorch_lightning
|
||||
diffusers
|
||||
invisible-watermark
|
||||
scikit-image
|
||||
git+https://github.com/crowsonkb/k-diffusion.git
|
||||
git+https://github.com/TencentARC/GFPGAN.git
|
||||
|
@ -8,3 +8,4 @@ torch
|
||||
transformers==4.19.2
|
||||
omegaconf==2.1.1
|
||||
pytorch_lightning==1.7.2
|
||||
scikit-image==0.19.2
|
||||
|
Loading…
Reference in New Issue
Block a user