2022-10-29 08:42:34 +03:00
import base64
import io
2023-01-02 19:42:10 +03:00
import math
2022-10-08 22:57:19 -05:00
import os
2022-09-23 22:49:21 +03:00
import re
2022-11-27 23:04:42 +03:00
from pathlib import Path
2022-09-23 22:49:21 +03:00
import gradio as gr
2022-10-08 22:57:19 -05:00
from modules . shared import script_path
2023-01-03 14:18:48 +03:00
from modules import shared , ui_tempdir
2022-10-27 13:36:11 +08:00
import tempfile
2022-10-29 10:56:19 +03:00
from PIL import Image
2022-09-23 22:49:21 +03:00
2022-10-21 16:10:51 +03:00
# One "key: value" pair of the infotext parameter line. A value is either a double-quoted string
# (as produced by quote(): backslashes and double quotes escaped with a backslash) or anything up
# to the next comma. The quoted form stops at the first unescaped closing quote, so several quoted
# values on the same line are parsed as separate pairs.
re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
# A whole parameter line: three or more "key: value" pairs.
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
# "WxH" size values such as "512x768".
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
# Trailing "(hash)" of a hypernetwork key such as "ke-ta-10000(1234abcd)". Raw string so the
# escaped parentheses are not read as (invalid) string escape sequences.
re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
2022-09-25 09:25:28 +03:00
# Concrete class of the object returned by gr.update(); connect_paste uses it to pass values that
# are already Gradio updates through unchanged.
type_of_gr_update = type(gr.update())

# tab name -> {"init_img": image component, "fields": list of (component, infotext key or callable)}
paste_fields: dict = {}
# [buttons, send_image, send_generate_info] entries queued by bind_buttons and consumed by run_bind
bind_list: list = []
2022-09-23 22:49:21 +03:00
2022-10-29 08:42:34 +03:00
2022-10-31 17:36:45 +03:00
def reset():
    """Forget every registered tab paste configuration and every queued button binding."""
    for registry in (paste_fields, bind_list):
        registry.clear()
2022-10-21 16:10:51 +03:00
def quote(text):
    """Make a value safe to embed in the comma-separated infotext parameter line.

    If the string form of ``text`` has no comma, ``text`` itself is returned unchanged (the
    original object, not its ``str``). Otherwise its string form is returned wrapped in double
    quotes, with backslashes and double quotes each escaped by a preceding backslash.
    """
    as_text = str(text)
    if ',' in as_text:
        escaped = as_text.replace('\\', '\\\\').replace('"', '\\"')
        return f'"{escaped}"'
    return text
2022-10-29 08:42:34 +03:00
2022-10-27 13:36:11 +08:00
def image_from_url_text(filedata):
    """Convert image data coming from a Gradio component into a PIL image.

    Accepted forms of ``filedata``:
    - a dict with a truthy "is_file" entry (or a list whose first element is such a dict): the
      file named by its "name" entry is opened, but only if ``ui_tempdir.check_tmp_file`` accepts it;
    - a list: only its first element is used, and an empty list yields None;
    - a base64 string, optionally prefixed with a ``data:<mediatype>;base64,`` header.

    Returns the opened ``Image``, or None for None / an empty list.
    Raises AssertionError when a file reference lies outside the allowed directories.
    """
    if filedata is None:
        return None

    if isinstance(filedata, list) and filedata and isinstance(filedata[0], dict) and filedata[0].get("is_file", False):
        filedata = filedata[0]

    if isinstance(filedata, dict) and filedata.get("is_file", False):
        filename = filedata["name"]
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`, which would
        # silently disable this path check. AssertionError is kept so callers see the same type.
        if not ui_tempdir.check_tmp_file(shared.demo, filename):
            raise AssertionError('trying to open image file outside of allowed directories')
        return Image.open(filename)

    if isinstance(filedata, list):
        if not filedata:
            return None
        filedata = filedata[0]

    # Drop a data-URL header ("data:image/png;base64," or any other media type) so only the
    # base64 payload is decoded.
    header, separator, payload = filedata.partition(",")
    if separator and header.startswith("data:") and header.endswith(";base64"):
        filedata = payload

    filedata = base64.decodebytes(filedata.encode('utf-8'))
    return Image.open(io.BytesIO(filedata))
2022-10-29 08:42:34 +03:00
2022-10-27 13:36:11 +08:00
def add_paste_fields(tabname, init_img, fields):
    """Register a tab's image component and its list of paste fields.

    ``fields`` is the tab's list of (component, infotext key) pairs, or None. For txt2img and
    img2img, the list is also exposed as ``modules.ui.txt2img_paste_fields`` /
    ``modules.ui.img2img_paste_fields`` for backwards compatibility with existing extensions.
    """
    paste_fields[tabname] = {"init_img": init_img, "fields": fields}

    import modules.ui
    legacy_attribute = {
        'txt2img': 'txt2img_paste_fields',
        'img2img': 'img2img_paste_fields',
    }.get(tabname)
    if legacy_attribute is not None:
        setattr(modules.ui, legacy_attribute, fields)
2022-10-27 13:36:11 +08:00
2022-10-29 08:42:34 +03:00
2022-10-29 10:56:19 +03:00
def integrate_settings_paste_fields(component_dict):
    """Extend every registered tab's paste fields with settings-backed entries.

    Each setting name below is paired with an infotext key. The component comes from
    ``component_dict[setting name]``, and pasting calls ``ui.apply_setting(setting name, value)``,
    where value is the parsed infotext value, or None when the key is absent. Tabs registered
    with ``fields=None`` are skipped; the other field lists are extended in place.
    """
    from modules import ui

    infotext_key_by_setting = {
        'sd_hypernetwork': 'Hypernet',
        'sd_hypernetwork_strength': 'Hypernet strength',
        'CLIP_stop_at_last_layers': 'Clip skip',
        'inpainting_mask_weight': 'Conditional mask weight',
        'sd_model_checkpoint': 'Model hash',
        'eta_noise_seed_delta': 'ENSD',
        'initial_noise_multiplier': 'Noise multiplier',
    }

    def make_applier(setting_name, infotext_key):
        # Factory so each callable captures its own setting/key pair.
        return lambda params: ui.apply_setting(setting_name, params.get(infotext_key, None))

    extra_fields = [
        (component_dict[setting_name], make_applier(setting_name, infotext_key))
        for setting_name, infotext_key in infotext_key_by_setting.items()
    ]

    for tab_info in paste_fields.values():
        if tab_info["fields"] is not None:
            tab_info["fields"] += extra_fields
2022-10-27 13:36:11 +08:00
def create_buttons(tabs_list):
    """Create one "Send to <tab>" button (elem_id "<tab>_tab") per tab name.

    Returns a dict mapping each tab name to its button.
    """
    return {tab: gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab") for tab in tabs_list}
2022-10-29 08:42:34 +03:00
2022-10-29 09:01:04 +03:00
def bind_buttons(buttons, send_image, send_generate_info):
    """Queue "Send to" buttons so run_bind() can attach their click handlers later.

    buttons: mapping of destination tab name -> button (e.g. as returned by create_buttons).
    send_image: component whose image is sent, or a falsy value to send no image.
    send_generate_info: if it is a tab name, the generation info is copied from that tab's
        registered paste fields; otherwise it is the component whose infotext is parsed and pasted.
    """
    queued_entry = [buttons, send_image, send_generate_info]
    bind_list.append(queued_entry)
2022-10-29 08:42:34 +03:00
2023-01-02 22:44:46 +03:00
def send_image_and_dimensions(x):
    """Prepare the outputs of a "Send to" click: returns (image, width, height).

    ``x`` is used as-is when it is already a PIL image; otherwise it is decoded with
    image_from_url_text. Width and height are the image's size when ``shared.opts.send_size`` is
    enabled and an image was obtained. Otherwise they are no-op ``gr.update()`` values.
    """
    img = x if isinstance(x, Image.Image) else image_from_url_text(x)

    if not (shared.opts.send_size and isinstance(img, Image.Image)):
        return img, gr.update(), gr.update()

    return img, img.width, img.height
2022-10-27 13:36:11 +08:00
def _match_fields_by_name(names, source_fields, destination_fields):
    """Pair components from two (component, name) lists that share a name listed in ``names``.

    A name is kept only if BOTH lists contain it; when a list registers a name more than once, its
    first component is used. Returns (inputs, outputs): two equally long lists ordered like
    ``names``, so the i-th input feeds the i-th output.
    """
    def first_component_per_name(fields):
        found = {}
        for component, name in fields or []:
            if name in names and name not in found:
                found[name] = component
        return found

    source = first_component_per_name(source_fields)
    destination = first_component_per_name(destination_fields)
    common = [name for name in names if name in source and name in destination]
    return [source[name] for name in common], [destination[name] for name in common]


def run_bind():
    """Attach the click handlers for every "Send to <tab>" button queued by bind_buttons.

    For each (tab, button) pair:
    - if there is a source image component and the tab registered an "init_img" component, a click
      copies the image. Gallery sources go through the `extract_image_from_gallery` JS function
      first. When the tab registered both "Size-1" and "Size-2" fields, the image's width/height are
      also sent via send_image_and_dimensions;
    - if generation info is to be sent and the tab registered paste fields:
      - when ``send_generate_info`` is another registered tab name, that tab's Prompt / Negative
        prompt / Steps / Face restoration components (plus Seed if ``shared.opts.send_seed``, read
        once here at bind time) are copied into the same-named components of this tab;
      - otherwise ``send_generate_info`` is the input component handed to connect_paste;
    - finally the `switch_to_<tab>` JS function switches the UI to the destination tab.
    """
    for buttons, source_image_component, send_generate_info in bind_list:
        for tab, button in buttons.items():
            destination_image_component = paste_fields[tab]["init_img"]
            fields = paste_fields[tab]["fields"]
            destination_width_component = next((field for field, name in fields or [] if name == "Size-1"), None)
            destination_height_component = next((field for field, name in fields or [] if name == "Size-2"), None)
            # Only send dimensions when both size components exist; a lone one would put None into outputs.
            send_dimensions = bool(destination_width_component and destination_height_component)

            if source_image_component and destination_image_component:
                if isinstance(source_image_component, gr.Gallery):
                    func = send_image_and_dimensions if send_dimensions else image_from_url_text
                    jsfunc = "extract_image_from_gallery"
                else:
                    func = send_image_and_dimensions if send_dimensions else lambda x: x
                    jsfunc = None

                button.click(
                    fn=func,
                    _js=jsfunc,
                    inputs=[source_image_component],
                    outputs=[destination_image_component, destination_width_component, destination_height_component] if send_dimensions else [destination_image_component],
                )

            if send_generate_info and fields is not None:
                if send_generate_info in paste_fields:
                    paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
                    # Pair by name so tabs whose field lists differ in order or content still line up.
                    inputs, outputs = _match_fields_by_name(paste_field_names, paste_fields[send_generate_info]["fields"], fields)
                    if outputs:
                        button.click(
                            # Gradio treats a single-output handler's return value as that output's
                            # value, so unwrap the 1-tuple in that case.
                            fn=lambda *x: x[0] if len(x) == 1 else x,
                            inputs=inputs,
                            outputs=outputs,
                        )
                else:
                    connect_paste(button, fields, send_generate_info)

            button.click(
                fn=None,
                _js=f"switch_to_{tab}",
                inputs=None,
                outputs=None,
            )
2022-10-29 08:42:34 +03:00
2022-12-13 14:25:16 -08:00
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
    """Pick the ``shared.hypernetworks`` key matching a hypernet mentioned in an infotext.

    With a hash (infotext "Hypernet: ke-ta" plus "Hypernet hash: 1234abcd"), the first key whose
    trailing "(hash)" equals it is returned, e.g. "ke-ta-10000(1234abcd)"; the name is not
    consulted in that case. Without a hash, the first key that starts with the name
    (case-insensitively) is returned. Returns None when nothing matches.
    """
    name_prefix = hypernet_name.lower()
    keys = shared.hypernetworks.keys()

    if hypernet_hash is None:
        return next((key for key in keys if key.lower().startswith(name_prefix)), None)

    for key in keys:
        found = re_hypernet_hash.search(key)
        if found is not None and found[1] == hypernet_hash:
            return key
    return None
2023-01-02 19:42:10 +03:00
def restore_old_hires_fix_params(res):
    """Convert a legacy "First pass size" infotext entry into the current hires-fix fields.

    Older infotexts stored the first-pass resolution ("First pass size-1"/"-2") next to the final
    size ("Size-1"/"-2", default 512). This rewrites ``res`` in place:
    - "Size-1" / "Size-2" are set to the first-pass width/height (ints);
    - "Hires upscale" is set to final width / first-pass width, or the height ratio when the
      first-pass width is not positive.

    A first-pass size of 0 in either dimension means the old auto-sizing was used. In that case the
    first-pass size is recomputed so its area is about 512*512 pixels, with each side rounded up
    to a multiple of 64.

    ``res`` is left untouched in three cases: the legacy entry is absent, a size is not an
    integer, or auto-sizing would need a non-positive final size.
    """
    firstpass_width = res.get('First pass size-1', None)
    firstpass_height = res.get('First pass size-2', None)
    if firstpass_width is None or firstpass_height is None:
        return

    try:
        firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
        width = int(res.get("Size-1", 512))
        height = int(res.get("Size-2", 512))
    except ValueError:
        # Malformed infotext: keep the parsed values instead of aborting the whole paste.
        return

    if firstpass_width == 0 or firstpass_height == 0:
        if width <= 0 or height <= 0:
            # The computation below would hit a division by zero or sqrt of a negative number.
            return
        # old algorithm for auto-calculating first pass size
        desired_pixel_count = 512 * 512
        actual_pixel_count = width * height
        scale = math.sqrt(desired_pixel_count / actual_pixel_count)
        firstpass_width = math.ceil(scale * width / 64) * 64
        firstpass_height = math.ceil(scale * height / 64) * 64

    hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height

    res['Size-1'] = firstpass_width
    res['Size-2'] = firstpass_height
    res['Hires upscale'] = hr_scale
2022-09-23 22:49:21 +03:00
def parse_generation_parameters(x: str):
    """Parse the generation-parameters text shown under a generated image in the UI.

    Example input::

        a portrait of a cat, detailed
        Negative prompt: blurry, lowres
        Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b

    The last line is read as comma-separated "key: value" pairs only if it has at least three of
    them; otherwise it is treated as prompt text. Of the remaining lines:
    - those before the first "Negative prompt:" line form "Prompt";
    - that line (minus its prefix) and the ones after it form "Negative prompt".

    Post-processing of the parsed values:
    - "WxH" values are split into "<key>-1" and "<key>-2";
    - missing "Clip skip" and "Hypernet strength" are filled in as "1";
    - "Hypernet" is resolved with find_hypernetwork_key (and may become None);
    - legacy first-pass sizes are converted by restore_old_hires_fix_params.

    Returns a dict keyed by infotext field name. Values are strings except where those helpers
    store other values.
    """
    negative_prefix = "Negative prompt:"

    *lines, lastline = x.strip().split("\n")
    if not re_params.match(lastline):
        # Fewer than three "key: value" pairs: the last line is ordinary prompt text.
        lines.append(lastline)
        lastline = ''

    prompt = ""
    negative_prompt = ""
    in_negative_prompt = False
    for raw_line in lines:
        line = raw_line.strip()
        if line.startswith(negative_prefix):
            in_negative_prompt = True
            line = line[len(negative_prefix):].strip()
        # Lines are joined with "\n", but the accumulated text never starts with a separator.
        if in_negative_prompt:
            negative_prompt = line if negative_prompt == "" else negative_prompt + "\n" + line
        else:
            prompt = line if prompt == "" else prompt + "\n" + line

    res = {"Prompt": prompt, "Negative prompt": negative_prompt}

    for key, value in re_param.findall(lastline):
        size = re_imagesize.match(value)
        if size is None:
            res[key] = value
        else:
            res[key + "-1"], res[key + "-2"] = size.group(1), size.group(2)

    # Missing "Clip skip" means it was set to 1 (the default); "Hypernet strength" is likewise filled in as "1".
    res.setdefault("Clip skip", "1")
    res.setdefault("Hypernet strength", "1")

    if "Hypernet" in res:
        res["Hypernet"] = find_hypernetwork_key(res["Hypernet"], res.get("Hypernet hash", None))

    restore_old_hires_fix_params(res)

    return res
2022-10-29 09:01:04 +03:00
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
    """Bind ``button`` so that clicking it pastes generation parameters into components.

    Parameters:
        button: the Gradio button to bind.
        paste_fields: list of (component, key) pairs. ``key`` is an infotext key, or a callable
            that receives the parsed parameter dict and returns the value.
        input_comp: component whose text is parsed with parse_generation_parameters.
        jsfunc: optional name of a JS function, passed to Gradio as ``_js``.

    When the text is empty and ``shared.cmd_opts.hide_ui_dir_config`` is not set, params.txt in
    ``script_path`` is read instead, if it exists. Each component then receives:
    - ``gr.update()`` (no change) when its value is None or cannot be converted;
    - the value itself when it is already a Gradio update;
    - otherwise ``gr.update(value=...)`` with the value converted to the type of the component's
      current value, where the string "False" becomes False for boolean components.
    """
    def paste_func(prompt):
        if not prompt and not shared.cmd_opts.hide_ui_dir_config:
            params_path = os.path.join(script_path, "params.txt")
            if os.path.exists(params_path):
                with open(params_path, "r", encoding="utf8") as handle:
                    prompt = handle.read()

        params = parse_generation_parameters(prompt)

        def make_update(component, key):
            value = key(params) if callable(key) else params.get(key, None)
            if value is None:
                return gr.update()
            if isinstance(value, type_of_gr_update):
                return value
            try:
                target_type = type(component.value)
                # bool("False") is True, so the literal string is mapped explicitly.
                converted = False if (target_type == bool and value == "False") else target_type(value)
                return gr.update(value=converted)
            except Exception:
                return gr.update()

        return [make_update(component, key) for component, key in paste_fields]

    button.click(
        fn=paste_func,
        _js=jsfunc,
        inputs=[input_comp],
        outputs=[entry[0] for entry in paste_fields],
    )
2022-10-27 13:36:11 +08:00