2022-09-03 12:08:45 +03:00
import json
import math
import os
import sys
2022-11-19 12:01:51 +03:00
import warnings
2022-09-03 12:08:45 +03:00
import torch
import numpy as np
from PIL import Image , ImageFilter , ImageOps
import random
2022-09-13 12:51:57 +03:00
import cv2
from skimage import exposure
2022-10-17 19:10:36 +00:00
from typing import Any , Dict , List , Optional
2022-09-03 12:08:45 +03:00
2022-09-05 03:25:37 +03:00
import modules . sd_hijack
2022-10-21 16:10:51 +03:00
from modules import devices , prompt_parser , masking , sd_samplers , lowvram , generation_parameters_copypaste
2022-09-03 12:08:45 +03:00
from modules . sd_hijack import model_hijack
from modules . shared import opts , cmd_opts , state
import modules . shared as shared
2022-09-07 12:32:28 +03:00
import modules . face_restoration
2022-09-03 12:08:45 +03:00
import modules . images as images
2022-09-09 23:16:02 +03:00
import modules . styles
2022-09-23 00:57:42 +00:00
import logging
2022-09-03 12:08:45 +03:00
2022-09-13 12:51:57 +03:00
2022-09-03 12:08:45 +03:00
# Fixed model hyper-parameters. They used to be user-configurable options but were
# removed, because changing them would break the model.
# NOTE(review): presumably opt_C is the latent channel count and opt_f the
# pixel -> latent downscale factor (create_random_tensors also divides by 8) — confirm.
opt_C, opt_f = 4, 8
2022-09-13 12:51:57 +03:00
def setup_color_correction(image):
    """Capture a colour-correction target from a reference image.

    A copy of `image` is converted to a NumPy array and then from RGB to the LAB
    colour space. The returned LAB array is presumably what callers later pass as the
    `correction` argument of `apply_color_correction`.
    """
    logging.info("Calibrating color correction.")
    rgb_pixels = np.asarray(image.copy())
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)
def apply_color_correction(correction, image):
    """Match the colour distribution of `image` to the `correction` target.

    `image` is converted RGB -> LAB and its histograms are matched against
    `correction` with `exposure.match_histograms(..., channel_axis=2)`. The result is
    converted back LAB -> RGB, cast to uint8 and returned as a new PIL image.
    """
    logging.info("Applying color correction.")

    lab_pixels = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(lab_pixels, correction, channel_axis=2)
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB).astype("uint8")

    return Image.fromarray(matched_rgb)
2022-10-24 09:15:26 +03:00
def apply_overlay(image, paste_loc, index, overlays):
    """Composite the overlay for batch position `index` on top of `image`.

    If there is no overlay for this index (`overlays` is None or has no entry at
    `index`), `image` is returned unchanged.

    When `paste_loc` is an (x, y, w, h) tuple, `image` is resized to w x h via
    `images.resize_image` (mode 1). It is then pasted at (x, y) onto a new RGBA canvas
    the size of the overlay.

    Finally the overlay is alpha-composited over the (possibly pasted) image, and the
    result is returned converted to RGB.
    """
    if overlays is None or index >= len(overlays):
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        left, top, target_w, target_h = paste_loc
        canvas = Image.new('RGBA', (overlay.width, overlay.height))
        canvas.paste(images.resize_image(1, image, target_w, target_h), (left, top))
        image = canvas

    composited = image.convert('RGBA')
    composited.alpha_composite(overlay)
    return composited.convert('RGB')
2022-09-13 12:51:57 +03:00
2022-10-08 18:13:13 -06:00
2022-10-17 19:10:36 +00:00
class StableDiffusionProcessing():
    """
    Parameters and shared helpers for one generation job.

    Subclasses are expected to override `init` and `sample` (the base `sample` raises
    NotImplementedError).

    The first set of parameters, sd_model -> do_not_reload_embeddings, represents the
    minimum required to create a StableDiffusionProcessing.
    """

    # The sampler used by this job. Several code paths read `p.sampler` (the
    # *_image_conditioning helpers, create_random_tensors, create_infotext), and
    # close() resets it to None, but __init__ never assigns it. Declaring it at class
    # level makes an unassigned sampler read as None instead of raising
    # AttributeError. Instance assignments (e.g. by subclasses) override this default.
    sampler = None

    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: Optional[List[str]] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: Optional[str] = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Optional[Dict[Any, Any]] = None, overlay_images: Any = None, negative_prompt: Optional[str] = None, eta: Optional[float] = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: Optional[str] = None, s_churn: float = 0.0, s_tmax: Optional[float] = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Optional[Dict[str, Any]] = None, sampler_index: Optional[int] = None):
        # sampler_index is accepted only for backward compatibility; it is ignored.
        if sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_name: str = sampler_name
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = denoising_strength
        self.sampler_noise_scheduler_override = None
        # Falsy values (None / 0) fall back to the global option.
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
        self.s_noise = s_noise or opts.s_noise
        # Settings listed in shared.restricted_opts can never be overridden per job.
        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
        self.is_using_inpainting_conditioning = False

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        self.scripts = None
        self.script_args = None
        self.all_prompts = None
        self.all_negative_prompts = None
        self.all_seeds = None
        self.all_subseeds = None

    def txt2img_image_conditioning(self, x, width=None, height=None):
        """Build the image-conditioning tensor for txt2img.

        If the sampler's `conditioning_key` is neither 'hybrid' nor 'concat', this
        returns a dummy zeros tensor of shape (x.shape[0], 5, 1, 1).

        Otherwise it sets `is_using_inpainting_conditioning`. It then encodes an
        all-zero image of shape (x.shape[0], 3, height, width) through the model's
        first stage, and prepends one channel filled with 1.0 (a full mask).
        `width` / `height` default to `self.width` / `self.height`. The result is cast
        to `x.dtype`.
        """
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Dummy zero conditioning if we're not using inpainting model.
            # Still takes up a bit of memory, but no encoder call.
            # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
            return x.new_zeros(x.shape[0], 5, 1, 1)

        self.is_using_inpainting_conditioning = True

        height = height or self.height
        width = width or self.width

        # The "masked-image" in this case will just be all zeros since the entire image is masked.
        image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
        image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))

        # Add the fake full 1s mask to the first dimension.
        # F.pad pads from the last dim backwards: (W, W, H, H, C-front, C-back).
        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
        image_conditioning = image_conditioning.to(x.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
        """Build the image-conditioning tensor for img2img / inpainting.

        If the sampler's `conditioning_key` is neither 'hybrid' nor 'concat', this
        returns a dummy zeros tensor of shape (latent_image.shape[0], 5, 1, 1).

        Otherwise:
        - `image_mask` (a tensor, or an image convertible to mode "L") becomes a
          rounded 0/1 mask; with no mask, an all-ones mask covering source_image's
          spatial size is used.
        - `source_image` is blended towards its masked version by
          `inpainting_mask_weight` (instance attribute, else the shared option).
        - The blend is encoded through the first stage.
        - The mask is resized to latent_image's spatial size and concatenated in
          front of the encoding along dim 1.
        """
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Dummy zero conditioning if we're not using inpainting model.
            return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Per-job setup hook, called once before sampling; no-op in the base class."""
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Produce latent samples for one batch; must be implemented by subclasses."""
        raise NotImplementedError()

    def close(self):
        """Drop references to the model and sampler held by this job."""
        self.sd_model = None
        self.sampler = None
2022-09-03 12:08:45 +03:00
class Processed:
    """Result of a generation job.

    Holds the produced images together with the parameters (copied from the
    StableDiffusionProcessing object and the global options/state) needed to describe
    them.
    """

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
        self.images = images_list
        self.info = info
        self.index_of_first_image = index_of_first_image

        # Parameters copied verbatim from the processing object.
        copied_from_p = (
            "prompt", "negative_prompt", "subseed_strength", "width", "height",
            "sampler_name", "cfg_scale", "steps", "batch_size", "restore_faces",
            "seed_resize_from_w", "seed_resize_from_h", "extra_generation_params",
            "styles", "eta", "ddim_discretize", "s_churn", "s_tmin", "s_tmax", "s_noise",
            "sampler_noise_scheduler_override", "is_using_inpainting_conditioning",
        )
        for attr_name in copied_from_p:
            setattr(self, attr_name, getattr(p, attr_name))

        # Values derived from p (optionally present) or from global options/state.
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers

        def first_if_list(value):
            # Per-image lists are reduced to their first entry for the scalar fields.
            return value[0] if type(value) == list else value

        self.prompt = first_if_list(self.prompt)
        self.negative_prompt = first_if_list(self.negative_prompt)
        self.seed = -1 if seed is None else int(first_if_list(seed))
        self.subseed = -1 if subseed is None else int(first_if_list(subseed))

        # Explicit arguments win; otherwise use p's lists; otherwise a single-item list.
        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        """Serialise the job description to a JSON string (key order is fixed)."""
        summary = {
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
        }

        # The remaining keys are serialised under the same name as the attribute.
        same_name_keys = (
            "seed", "all_seeds", "subseed", "all_subseeds", "subseed_strength",
            "width", "height", "sampler_name", "cfg_scale", "steps", "batch_size",
            "restore_faces", "face_restoration_model", "sd_model_hash",
            "seed_resize_from_w", "seed_resize_from_h", "denoising_strength",
            "extra_generation_params", "index_of_first_image", "infotexts", "styles",
            "job_timestamp", "clip_skip", "is_using_inpainting_conditioning",
        )
        for key in same_name_keys:
            summary[key] = getattr(self, key)

        return json.dumps(summary)

    def infotext(self, p: StableDiffusionProcessing, index):
        """Return the infotext for the image at flat position `index` across all batches."""
        iteration, position_in_batch = divmod(index, self.batch_size)
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=position_in_batch, iteration=iteration)
2022-09-09 17:54:04 +03:00
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    """Spherically interpolate from `low` (val == 0) to `high` (val == 1).

    The angle between the inputs is measured after normalising them along dim 1; the
    result is the standard slerp weighting of the (unnormalised) inputs.

    When the inputs are, on average, almost parallel (mean cosine > 0.9995),
    sin(omega) approaches 0 and the spherical formula becomes unstable. In that case
    plain linear interpolation is used instead.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)

    if dot.mean() > 0.9995:
        # Linear fallback, oriented like the spherical branch below. This previously
        # computed low * val + high * (1 - val), which returned `high` at val == 0.
        # When low == high (e.g. seed == subseed) both forms give the same result.
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
2022-09-03 12:08:45 +03:00
2022-09-09 17:54:04 +03:00
2022-09-13 21:49:58 +03:00
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Build the initial noise for a batch: one sample per entry of `seeds`.

    - Each sample is drawn with devices.randn(seed, noise_shape).
    - If `subseeds` is given, a second sample is drawn from the matching subseed (0
      when `subseeds` is shorter than `seeds`). It is blended in with
      slerp(subseed_strength, noise, subnoise).
    - If seed_resize_from_h and seed_resize_from_w are both > 0, noise is first drawn
      at (shape[0], seed_resize_from_h // 8, seed_resize_from_w // 8). It is then
      placed into a full-`shape` tensor drawn from the same seed: pasted at the
      (floor-)centre if smaller, centre-cropped if larger.
      NOTE(review): presumably `shape` is (latent channels, H/8, W/8) — confirm with callers.
    - If `p` has a sampler, extra per-seed noises may be generated and stored in
      p.sampler.sampler_noises (see the condition below).

    Returns the stacked samples, shape (len(seeds), *shape), moved to shared.device.
    Note: the order of RNG calls here defines user-visible seeds; do not reorder.
    """
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    # Precedence: (len(seeds) > 1 and enable_batch_seeds) or eta_noise_seed_delta > 0.
    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        # Draw at the "resize from" size when both resize dimensions are positive.
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            # Place the resized-seed noise in the middle of full-size noise from the same seed.
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            # w/h: size of the overlapping region; when the source is larger (dx/dy < 0) it shrinks by 2*|d|.
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            # tx/ty: offset into the target; dx/dy (reused below): offset into the source.
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            # Reseeds the global torch RNG; randn_without_seed presumably draws from it (TODO confirm).
            if opts.eta_noise_seed_delta > 0:
                torch.manual_seed(seed + opts.eta_noise_seed_delta)

            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x
2022-10-10 16:11:14 +03:00
def decode_first_stage(model, x):
    """Decode latents `x` with `model.decode_first_stage`.

    Autocast is disabled when `x` already has the VAE dtype (`devices.dtype_vae`).
    """
    already_vae_dtype = x.dtype == devices.dtype_vae
    with devices.autocast(disable=already_vae_dtype):
        decoded = model.decode_first_stage(x)
    return decoded
2022-10-04 17:36:39 +03:00
def get_fixed_seed(seed):
    """Resolve a "random seed" placeholder to a concrete seed.

    None, '' and -1 each mean "pick one at random": a fresh integer in
    [0, 4294967294) is drawn with `random.randrange`. Any other value, including a
    list of seeds, is returned unchanged.
    """
    wants_random = seed is None or seed == '' or seed == -1
    if not wants_random:
        return seed
    return int(random.randrange(4294967294))
2022-09-09 17:54:04 +03:00
def fix_seed(p):
    """Resolve random-seed placeholders on `p.seed` and `p.subseed` in place via get_fixed_seed."""
    for attr_name in ("seed", "subseed"):
        setattr(p, attr_name, get_fixed_seed(getattr(p, attr_name)))
2022-09-07 01:44:44 +03:00
2022-09-19 09:02:10 +03:00
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    """Build the human-readable parameters text for one image of a job.

    The text has the form "<prompt>[\\nNegative prompt: <neg>]\\n<k: v, k: v, ...>",
    stripped of surrounding whitespace.

    The image is selected by `position_in_batch + iteration * p.batch_size`, which
    indexes `all_prompts`, `all_seeds`, `all_subseeds` and `p.all_negative_prompts`.
    Parameters whose value is None are omitted. Entries from `p.extra_generation_params`
    are appended, overriding same-named keys. A key whose value equals the key itself
    is emitted bare (no ": value").

    NOTE(review): `comments` is accepted but not used in this body.
    NOTE(review): the negative prompt comes from `p.all_negative_prompts`, not from an
    argument, unlike the positive prompt.
    """
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)

    # Insertion order defines the order of fields in the output text.
    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        # ',' and ':' are stripped from the model name, presumably because they delimit fields in this text.
        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
        "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
    }

    generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    # Parses as ("\nNegative prompt: " + neg) if neg else "".
    negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
2022-09-19 09:02:10 +03:00
2022-09-03 12:08:45 +03:00
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Run process_images_inner(p) with p.override_settings temporarily applied to opts.

    The current values of the overridden options are captured first and written back
    in a `finally` block, so the global options are restored even if processing raises.

    onchange callbacks are not called, so overriding the SD model this way has no
    effect. The exception is that hypernetworks are reloaded whenever
    'sd_hypernetwork' is set or restored, since that load is relatively fast.
    """
    def assign_options(settings):
        for option_name, option_value in settings.items():
            setattr(opts, option_name, option_value)
            if option_name == 'sd_hypernetwork':
                shared.reload_hypernetworks()

    original_values = {option_name: opts.data[option_name] for option_name in p.override_settings}

    try:
        assign_options(p.override_settings)
        return process_images_inner(p)
    finally:
        # restore opts to their original state
        assign_options(original_values)
def process_images_inner ( p : StableDiffusionProcessing ) - > Processed :
2022-09-03 12:08:45 +03:00
""" this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch """
2022-09-17 01:34:33 -07:00
if type ( p . prompt ) == list :
assert ( len ( p . prompt ) > 0 )
else :
assert p . prompt is not None
2022-10-04 09:49:51 -06:00
2022-09-11 23:24:24 +03:00
devices . torch_gc ( )
2022-09-03 12:08:45 +03:00
2022-10-04 17:36:39 +03:00
seed = get_fixed_seed ( p . seed )
subseed = get_fixed_seed ( p . subseed )
2022-09-03 12:08:45 +03:00
2022-09-05 03:25:37 +03:00
modules . sd_hijack . model_hijack . apply_circular ( p . tiling )
2022-10-07 17:48:34 -04:00
modules . sd_hijack . model_hijack . clear_comments ( )
2022-09-05 03:25:37 +03:00
2022-09-15 08:57:03 +03:00
comments = { }
2022-09-09 23:16:02 +03:00
if type ( p . prompt ) == list :
2022-11-19 13:23:25 +03:00
p . all_prompts = [ shared . prompt_styles . apply_styles_to_prompt ( x , p . styles ) for x in p . prompt ]
else :
p . all_prompts = p . batch_size * p . n_iter * [ shared . prompt_styles . apply_styles_to_prompt ( p . prompt , p . styles ) ]
if type ( p . negative_prompt ) == list :
p . all_negative_prompts = [ shared . prompt_styles . apply_negative_styles_to_prompt ( x , p . styles ) for x in p . negative_prompt ]
2022-09-03 12:08:45 +03:00
else :
2022-11-19 13:23:25 +03:00
p . all_negative_prompts = p . batch_size * p . n_iter * [ shared . prompt_styles . apply_negative_styles_to_prompt ( p . negative_prompt , p . styles ) ]
2022-09-03 17:21:15 +03:00
2022-10-04 17:36:39 +03:00
if type ( seed ) == list :
2022-10-22 12:23:45 +03:00
p . all_seeds = seed
2022-09-03 17:21:15 +03:00
else :
2022-10-22 12:23:45 +03:00
p . all_seeds = [ int ( seed ) + ( x if p . subseed_strength == 0 else 0 ) for x in range ( len ( p . all_prompts ) ) ]
2022-09-09 17:54:04 +03:00
2022-10-04 17:36:39 +03:00
if type ( subseed ) == list :
2022-10-22 12:23:45 +03:00
p . all_subseeds = subseed
2022-09-09 17:54:04 +03:00
else :
2022-10-22 12:23:45 +03:00
p . all_subseeds = [ int ( subseed ) + x for x in range ( len ( p . all_prompts ) ) ]
2022-09-03 12:08:45 +03:00
def infotext ( iteration = 0 , position_in_batch = 0 ) :
2022-10-22 12:23:45 +03:00
return create_infotext ( p , p . all_prompts , p . all_seeds , p . all_subseeds , comments , iteration , position_in_batch )
2022-09-03 12:08:45 +03:00
2022-11-19 13:23:25 +03:00
with open ( os . path . join ( shared . script_path , " params.txt " ) , " w " , encoding = " utf8 " ) as file :
processed = Processed ( p , [ ] , p . seed , " " )
file . write ( processed . infotext ( p , 0 ) )
2022-10-16 08:51:24 +03:00
if os . path . exists ( cmd_opts . embeddings_dir ) and not p . do_not_reload_embeddings :
2022-10-02 15:03:39 +03:00
model_hijack . embedding_db . load_textual_inversion_embeddings ( )
2022-09-03 12:08:45 +03:00
2022-10-22 12:23:45 +03:00
if p . scripts is not None :
2022-10-29 22:20:02 +03:00
p . scripts . process ( p )
2022-10-22 12:23:45 +03:00
2022-09-28 17:05:23 +03:00
infotexts = [ ]
2022-09-03 12:08:45 +03:00
output_images = [ ]
2022-10-04 12:32:22 +03:00
2022-10-08 23:26:48 +03:00
with torch . no_grad ( ) , p . sd_model . ema_scope ( ) :
2022-10-04 16:54:31 +03:00
with devices . autocast ( ) :
2022-10-22 12:23:45 +03:00
p . init ( p . all_prompts , p . all_seeds , p . all_subseeds )
2022-09-03 12:08:45 +03:00
2022-09-06 10:11:25 +03:00
if state . job_count == - 1 :
state . job_count = p . n_iter
2022-09-06 02:09:01 +03:00
2022-10-04 22:28:50 -03:00
for n in range ( p . n_iter ) :
2022-10-04 22:56:30 -05:00
if state . skipped :
state . skipped = False
2022-09-03 12:08:45 +03:00
if state . interrupted :
break
2022-10-22 12:23:45 +03:00
prompts = p . all_prompts [ n * p . batch_size : ( n + 1 ) * p . batch_size ]
2022-11-19 13:23:25 +03:00
negative_prompts = p . all_negative_prompts [ n * p . batch_size : ( n + 1 ) * p . batch_size ]
2022-10-22 12:23:45 +03:00
seeds = p . all_seeds [ n * p . batch_size : ( n + 1 ) * p . batch_size ]
subseeds = p . all_subseeds [ n * p . batch_size : ( n + 1 ) * p . batch_size ]
2022-09-03 12:08:45 +03:00
2022-10-29 22:20:02 +03:00
if len ( prompts ) == 0 :
2022-09-17 01:34:33 -07:00
break
2022-11-02 19:05:01 +03:00
if p . scripts is not None :
2022-11-04 11:21:40 +03:00
p . scripts . process_batch ( p , batch_number = n , prompts = prompts , seeds = seeds , subseeds = subseeds )
2022-11-02 19:05:01 +03:00
2022-10-04 12:32:22 +03:00
with devices . autocast ( ) :
2022-11-19 13:23:25 +03:00
uc = prompt_parser . get_learned_conditioning ( shared . sd_model , negative_prompts , p . steps )
2022-10-05 23:16:27 +03:00
c = prompt_parser . get_multicond_learned_conditioning ( shared . sd_model , prompts , p . steps )
2022-09-03 12:08:45 +03:00
if len ( model_hijack . comments ) > 0 :
2022-09-15 08:57:03 +03:00
for comment in model_hijack . comments :
comments [ comment ] = 1
2022-09-03 12:08:45 +03:00
if p . n_iter > 1 :
2022-09-24 08:23:01 +03:00
shared . state . job = f " Batch { n + 1 } out of { p . n_iter } "
2022-09-03 12:08:45 +03:00
2022-10-04 12:32:22 +03:00
with devices . autocast ( ) :
2022-11-02 12:45:03 +03:00
samples_ddim = p . sample ( conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p . subseed_strength , prompts = prompts )
2022-10-04 12:32:22 +03:00
2022-11-28 16:29:43 +05:00
x_samples_ddim = [ decode_first_stage ( p . sd_model , samples_ddim [ i : i + 1 ] . to ( dtype = devices . dtype_vae ) ) [ 0 ] . cpu ( ) for i in range ( samples_ddim . size ( 0 ) ) ]
x_samples_ddim = torch . stack ( x_samples_ddim ) . float ( )
2022-09-03 12:08:45 +03:00
x_samples_ddim = torch . clamp ( ( x_samples_ddim + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
2022-09-28 22:14:13 -03:00
del samples_ddim
if shared . cmd_opts . lowvram or shared . cmd_opts . medvram :
lowvram . send_everything_to_cpu ( )
devices . torch_gc ( )
2022-09-12 19:15:35 -05:00
if opts . filter_nsfw :
2022-09-13 08:34:41 +03:00
import modules . safety as safety
x_samples_ddim = modules . safety . censor_batch ( x_samples_ddim )
2022-09-12 19:15:35 -05:00
2022-10-04 22:28:50 -03:00
for i , x_sample in enumerate ( x_samples_ddim ) :
2022-09-03 12:08:45 +03:00
x_sample = 255. * np . moveaxis ( x_sample . cpu ( ) . numpy ( ) , 0 , 2 )
x_sample = x_sample . astype ( np . uint8 )
2022-10-04 22:28:50 -03:00
if p . restore_faces :
2022-09-12 17:47:36 +03:00
if opts . save and not p . do_not_save_samples and opts . save_images_before_face_restoration :
2022-09-22 20:54:50 +10:00
images . save_image ( Image . fromarray ( x_sample ) , p . outpath_samples , " " , seeds [ i ] , prompts [ i ] , opts . samples_format , info = infotext ( n , i ) , p = p , suffix = " -before-face-restoration " )
2022-09-12 17:47:36 +03:00
2022-10-04 12:32:22 +03:00
devices . torch_gc ( )
2022-09-03 12:08:45 +03:00
2022-10-04 22:28:50 -03:00
x_sample = modules . face_restoration . restore_faces ( x_sample )
devices . torch_gc ( )
2022-09-28 22:14:13 -03:00
2022-09-03 12:08:45 +03:00
image = Image . fromarray ( x_sample )
2022-10-23 22:44:46 +03:00
2022-09-13 12:51:57 +03:00
if p . color_corrections is not None and i < len ( p . color_corrections ) :
2022-09-22 20:54:50 +10:00
if opts . save and not p . do_not_save_samples and opts . save_images_before_color_correction :
2022-10-24 09:15:26 +03:00
image_without_cc = apply_overlay ( image , p . paste_to , i , p . overlay_images )
2022-10-23 22:38:42 +03:00
images . save_image ( image_without_cc , p . outpath_samples , " " , seeds [ i ] , prompts [ i ] , opts . samples_format , info = infotext ( n , i ) , p = p , suffix = " -before-color-correction " )
2022-09-13 12:51:57 +03:00
image = apply_color_correction ( p . color_corrections [ i ] , image )
2022-09-12 17:47:36 +03:00
2022-10-24 09:15:26 +03:00
image = apply_overlay ( image , p . paste_to , i , p . overlay_images )
2022-09-03 12:08:45 +03:00
if opts . samples_save and not p . do_not_save_samples :
2022-09-12 15:41:30 +03:00
images . save_image ( image , p . outpath_samples , " " , seeds [ i ] , prompts [ i ] , opts . samples_format , info = infotext ( n , i ) , p = p )
2022-09-03 12:08:45 +03:00
2022-10-06 20:27:50 +03:00
text = infotext ( n , i )
infotexts . append ( text )
2022-10-09 13:10:15 +03:00
if opts . enable_pnginfo :
image . info [ " parameters " ] = text
2022-09-03 12:08:45 +03:00
output_images . append ( image )
2022-10-04 22:28:50 -03:00
del x_samples_ddim
2022-09-06 02:09:01 +03:00
2022-10-04 22:28:50 -03:00
devices . torch_gc ( )
2022-09-28 22:14:13 -03:00
2022-10-04 22:28:50 -03:00
state . nextjob ( )
2022-09-28 22:14:13 -03:00
2022-09-17 18:18:30 -04:00
p . color_corrections = None
2022-09-19 09:02:10 +03:00
index_of_first_image = 0
2022-09-03 12:08:45 +03:00
unwanted_grid_because_of_img_count = len ( output_images ) < 2 and opts . grid_only_if_multiple
2022-09-14 10:34:44 +03:00
if ( opts . return_grid or opts . grid_save ) and not p . do_not_save_grid and not unwanted_grid_because_of_img_count :
2022-09-03 17:21:15 +03:00
grid = images . image_grid ( output_images , p . batch_size )
2022-09-03 12:08:45 +03:00
2022-09-14 10:34:44 +03:00
if opts . return_grid :
2022-10-06 20:27:50 +03:00
text = infotext ( )
infotexts . insert ( 0 , text )
2022-10-09 13:10:15 +03:00
if opts . enable_pnginfo :
grid . info [ " parameters " ] = text
2022-09-03 12:08:45 +03:00
output_images . insert ( 0 , grid )
2022-09-19 09:02:10 +03:00
index_of_first_image = 1
2022-09-03 12:08:45 +03:00
if opts . grid_save :
2022-10-22 12:23:45 +03:00
images . save_image ( grid , p . outpath_grids , " grid " , p . all_seeds [ 0 ] , p . all_prompts [ 0 ] , opts . grid_format , info = infotext ( ) , short_filename = not opts . grid_extended_filename , p = p , grid = True )
2022-09-03 12:08:45 +03:00
2022-09-11 23:24:24 +03:00
devices . torch_gc ( )
2022-10-29 22:20:02 +03:00
2022-11-19 13:23:25 +03:00
res = Processed ( p , output_images , p . all_seeds [ 0 ] , infotext ( ) + " " . join ( [ " \n \n " + x for x in comments ] ) , subseed = p . all_subseeds [ 0 ] , index_of_first_image = index_of_first_image , infotexts = infotexts )
2022-10-29 22:20:02 +03:00
if p . scripts is not None :
p . scripts . postprocess ( p , res )
return res
2022-09-03 12:08:45 +03:00
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """Text-to-image processing with an optional two-pass "highres fix".

    When ``enable_hr`` is False, ``sample`` runs a single txt2img pass at
    ``width`` x ``height``. When it is True, ``sample`` first generates at
    ``firstphase_width`` x ``firstphase_height`` (chosen automatically in ``init``
    when either is 0), crops and upscales that result to ``width`` x ``height``,
    and then refines it with an img2img pass.
    """
    sampler = None

    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, **kwargs):
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.firstphase_width = firstphase_width
        self.firstphase_height = firstphase_height
        # Number of latent cells (each opt_f pixels wide) to crop from the first-pass
        # result so that its aspect ratio matches the final width/height; set in init().
        self.truncate_x = 0
        self.truncate_y = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare highres-fix state before sampling; does nothing when highres fix is off."""
        if self.enable_hr:
            # Every batch runs two sampling passes, so double the job count.
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            else:
                state.job_count = state.job_count * 2

            # Recorded before the auto-size computation below, so the values stored here
            # are the ones that were passed in (0 when the size is left to be chosen).
            self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"

            if self.firstphase_width == 0 or self.firstphase_height == 0:
                # Scale the target size to roughly 512*512 pixels, keeping its aspect ratio,
                # and round each side up to a multiple of 64 for sampling.
                desired_pixel_count = 512 * 512
                actual_pixel_count = self.width * self.height
                scale = math.sqrt(desired_pixel_count / actual_pixel_count)
                self.firstphase_width = math.ceil(scale * self.width / 64) * 64
                self.firstphase_height = math.ceil(scale * self.height / 64) * 64
                # Unrounded size: the part of the first pass that matches the target aspect ratio.
                firstphase_width_truncated = int(scale * self.width)
                firstphase_height_truncated = int(scale * self.height)
            else:
                width_ratio = self.width / self.firstphase_width
                height_ratio = self.height / self.firstphase_height

                if width_ratio > height_ratio:
                    # Target is relatively wider: keep the full first-pass width, crop height.
                    firstphase_width_truncated = self.firstphase_width
                    firstphase_height_truncated = self.firstphase_width * self.height / self.width
                else:
                    # Target is relatively taller (or equal): keep the full height, crop width.
                    firstphase_width_truncated = self.firstphase_height * self.width / self.height
                    firstphase_height_truncated = self.firstphase_height

            # Convert the pixel excess into latent cells; sample() crops half from each side.
            self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
            self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Generate latent samples for one batch and return them.

        Without highres fix this is a single txt2img sampling pass. With it, the
        first-pass latents are cropped to the target aspect ratio, upscaled either in
        latent space (``opts.use_scale_latent_for_hires_fix``) or by decoding,
        resizing and re-encoding images, and then refined with ``sample_img2img``.
        """
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        if not self.enable_hr:
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
            return samples

        # First pass at the (smaller) first-phase resolution.
        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))

        # Crop the first-pass latents symmetrically so the aspect ratio matches the target size.
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]

        def save_intermediate(image, index):
            """Save an image from before the highres fix is applied, if enabled in options.

            ``image`` is either a PIL image or a batch of latent-space samples; in the
            latter case the sample at ``index`` is converted to an image first.
            """
            if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index)

            # Pass p like every other save_image call so filename patterns can use it.
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix", p=self)

        if opts.use_scale_latent_for_hires_fix:
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            decoded_samples = decode_first_stage(self.sd_model, samples)
            # Map decoder output from [-1, 1] to [0, 1].
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, self.width, self.height)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device)
            # Back to the [-1, 1] range expected by the encoder.
            decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

        # Fresh sampler for the second (img2img) pass.
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)

        return samples
2022-09-03 12:08:45 +03:00
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """Image-to-image (and inpainting) processing.

    ``init`` resizes/crops ``init_images``, encodes them into ``init_latent`` and,
    when a mask is given, builds the latent-space masks that ``sample`` uses to keep
    the unmasked areas of the init latent.
    """
    sampler = None

    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        # Optional override for the mask used in latent space; init() falls back to
        # the processed image mask when this is None.
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        # Latent-space keep mask (1 - latent mask) and its complement; set in init().
        self.mask = None
        self.nmask = None
        self.image_conditioning = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare the init latent, masks, overlays and color corrections for sampling.

        Raises:
            RuntimeError: if no init images are given, or more images than ``batch_size``
                (unless exactly one image is given, which is repeated to fill the batch).
        """
        if not self.init_images:
            # Fail early with a clear message instead of encoding an empty batch.
            raise RuntimeError("no init images passed; expecting at least one")

        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        crop_region = None

        image_mask = self.image_mask

        if image_mask is not None:
            image_mask = image_mask.convert('L')

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)

            if self.mask_blur > 0:
                image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                # Only the (padded) masked region is processed at full resolution;
                # paste_to records where the result goes back into the original image.
                self.mask_for_overlay = image_mask
                mask = image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                np_mask = np.array(image_mask)
                # Doubling (clipped at 255) widens the overlay mask so blurred edges count as masked.
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            # Only created when there is a mask; stays as-is (checked against None below) otherwise.
            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []

        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
                # Keep the unmasked part of the original image to paste over the result later.
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if image_mask is not None:
                # inpainting_fill == 1 keeps the original pixels under the mask.
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            # HWC uint8 -> CHW float in [0, 1]
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            # A single init image is repeated to fill the whole batch.
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        # [0, 1] -> [-1, 1] for the first-stage encoder
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if image_mask is not None:
            init_mask = latent_mask
            # Downscale the mask to the latent resolution and binarize it to 0/1.
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            # mask: 1 where the init latent is kept; nmask: 1 where new content is generated.
            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                # Replace the masked area of the init latent with random noise.
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                # Zero out the masked area of the init latent.
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run img2img sampling from ``init_latent`` and return the latent samples.

        When a mask was set up in ``init``, areas outside the mask are restored
        from ``init_latent`` after sampling.
        """
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples