Integrate new bucket parameters in GUI

This commit is contained in:
bmaltais 2023-02-05 20:07:00 -05:00
parent 2486af9903
commit cbfc311687
17 changed files with 414 additions and 217 deletions

View File

@ -82,8 +82,12 @@ def save_configuration(
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, keep_tokens, model_list,
keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -167,8 +171,12 @@ def open_configuration(
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, keep_tokens, model_list,
keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -239,6 +247,9 @@ def train_model(
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -402,6 +413,9 @@ def train_model(
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens, keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers, persistent_data_loader_workers=persistent_data_loader_workers,
bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps,
) )
print(run_cmd) print(run_cmd)
@ -610,6 +624,9 @@ def dreambooth_tab(
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -675,6 +692,9 @@ def dreambooth_tab(
model_list, model_list,
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
] ]
button_open_config.click( button_open_config.click(

View File

@ -78,8 +78,12 @@ def save_configuration(
color_aug, color_aug,
model_list, model_list,
cache_latents, cache_latents,
use_latent_files, keep_tokens, use_latent_files,
keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -169,8 +173,12 @@ def open_config_file(
color_aug, color_aug,
model_list, model_list,
cache_latents, cache_latents,
use_latent_files, keep_tokens, use_latent_files,
keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -245,8 +253,12 @@ def train_model(
color_aug, color_aug,
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
cache_latents, cache_latents,
use_latent_files, keep_tokens, use_latent_files,
keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
# create caption json file # create caption json file
if generate_caption_database: if generate_caption_database:
@ -295,7 +307,11 @@ def train_model(
subprocess.run(run_cmd) subprocess.run(run_cmd)
image_num = len( image_num = len(
[f for f in os.listdir(image_folder) if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')] [
f
for f in os.listdir(image_folder)
if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')
]
) )
print(f'image_num = {image_num}') print(f'image_num = {image_num}')
@ -386,6 +402,9 @@ def train_model(
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens, keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers, persistent_data_loader_workers=persistent_data_loader_workers,
bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps,
) )
print(run_cmd) print(run_cmd)
@ -592,6 +611,9 @@ def finetune_tab():
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -653,6 +675,9 @@ def finetune_tab():
use_latent_files, use_latent_files,
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
] ]
button_run.click(train_model, inputs=settings_list) button_run.click(train_model, inputs=settings_list)

View File

@ -19,7 +19,7 @@ def UI(username, password):
print('Load CSS...') print('Load CSS...')
css += file.read() + '\n' css += file.read() + '\n'
interface = gr.Blocks(css=css, title="Kohya_ss GUI") interface = gr.Blocks(css=css, title='Kohya_ss GUI')
with interface: with interface:
with gr.Tab('Dreambooth'): with gr.Tab('Dreambooth'):

View File

@ -10,13 +10,15 @@ def caption_images(
overwrite_input, overwrite_input,
caption_file_ext, caption_file_ext,
prefix, prefix,
postfix, find, replace postfix,
find,
replace,
): ):
# Check for images_dir_input # Check for images_dir_input
if images_dir_input == '': if images_dir_input == '':
msgbox('Image folder is missing...') msgbox('Image folder is missing...')
return return
if caption_file_ext == '': if caption_file_ext == '':
msgbox('Please provide an extension for the caption files.') msgbox('Please provide an extension for the caption files.')
return return
@ -39,7 +41,7 @@ def caption_images(
subprocess.run(run_cmd) subprocess.run(run_cmd)
if overwrite_input: if overwrite_input:
if not prefix=='' or not postfix=='': if not prefix == '' or not postfix == '':
# Add prefix and postfix # Add prefix and postfix
add_pre_postfix( add_pre_postfix(
folder=images_dir_input, folder=images_dir_input,
@ -47,7 +49,7 @@ def caption_images(
prefix=prefix, prefix=prefix,
postfix=postfix, postfix=postfix,
) )
if not find=='': if not find == '':
find_replace( find_replace(
folder=images_dir_input, folder=images_dir_input,
caption_file_ext=caption_file_ext, caption_file_ext=caption_file_ext,
@ -134,6 +136,7 @@ def gradio_basic_caption_gui_tab():
caption_file_ext, caption_file_ext,
prefix, prefix,
postfix, postfix,
find, replace find,
replace,
], ],
) )

View File

@ -26,7 +26,7 @@ def caption_images(
if train_data_dir == '': if train_data_dir == '':
msgbox('Image folder is missing...') msgbox('Image folder is missing...')
return return
if caption_file_ext == '': if caption_file_ext == '':
msgbox('Please provide an extension for the caption files.') msgbox('Please provide an extension for the caption files.')
return return

View File

@ -9,6 +9,7 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾 save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄 document_symbol = '\U0001F4C4' # 📄
def get_dir_and_file(file_path): def get_dir_and_file(file_path):
dir_path, file_name = os.path.split(file_path) dir_path, file_name = os.path.split(file_path)
return (dir_path, file_name) return (dir_path, file_name)
@ -200,7 +201,7 @@ def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)] files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
for file in files: for file in files:
with open(os.path.join(folder, file), 'r', errors="ignore") as f: with open(os.path.join(folder, file), 'r', errors='ignore') as f:
content = f.read() content = f.read()
f.close f.close
content = content.replace(find, replace) content = content.replace(find, replace)
@ -304,7 +305,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
### ###
### Gradio common GUI section ### Gradio common GUI section
### ###
def gradio_config(): def gradio_config():
with gr.Accordion('Configuration file', open=False): with gr.Accordion('Configuration file', open=False):
with gr.Row(): with gr.Row():
@ -318,7 +320,13 @@ def gradio_config():
placeholder="type the configuration file path or use the 'Open' button above to select it...", placeholder="type the configuration file path or use the 'Open' button above to select it...",
interactive=True, interactive=True,
) )
return (button_open_config, button_save_config, button_save_as_config, config_file_name) return (
button_open_config,
button_save_config,
button_save_as_config,
config_file_name,
)
def gradio_source_model(): def gradio_source_model():
with gr.Tab('Source model'): with gr.Tab('Source model'):
@ -382,9 +390,20 @@ def gradio_source_model():
v_parameterization, v_parameterization,
], ],
) )
return (pretrained_model_name_or_path, v2, v_parameterization, save_model_as, model_list) return (
pretrained_model_name_or_path,
v2,
v_parameterization,
save_model_as,
model_list,
)
def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', lr_warmup_value='0'):
def gradio_training(
learning_rate_value='1e-6',
lr_scheduler_value='constant',
lr_warmup_value='0',
):
with gr.Row(): with gr.Row():
train_batch_size = gr.Slider( train_batch_size = gr.Slider(
minimum=1, minimum=1,
@ -394,9 +413,7 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
step=1, step=1,
) )
epoch = gr.Textbox(label='Epoch', value=1) epoch = gr.Textbox(label='Epoch', value=1)
save_every_n_epochs = gr.Textbox( save_every_n_epochs = gr.Textbox(label='Save every N epochs', value=1)
label='Save every N epochs', value=1
)
caption_extension = gr.Textbox( caption_extension = gr.Textbox(
label='Caption Extension', label='Caption Extension',
placeholder='(Optional) Extension for caption files. default: .caption', placeholder='(Optional) Extension for caption files. default: .caption',
@ -429,7 +446,9 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
) )
seed = gr.Textbox(label='Seed', value=1234) seed = gr.Textbox(label='Seed', value=1234)
with gr.Row(): with gr.Row():
learning_rate = gr.Textbox(label='Learning rate', value=learning_rate_value) learning_rate = gr.Textbox(
label='Learning rate', value=learning_rate_value
)
lr_scheduler = gr.Dropdown( lr_scheduler = gr.Dropdown(
label='LR Scheduler', label='LR Scheduler',
choices=[ choices=[
@ -442,7 +461,9 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
], ],
value=lr_scheduler_value, value=lr_scheduler_value,
) )
lr_warmup = gr.Textbox(label='LR warmup (% of steps)', value=lr_warmup_value) lr_warmup = gr.Textbox(
label='LR warmup (% of steps)', value=lr_warmup_value
)
cache_latents = gr.Checkbox(label='Cache latent', value=True) cache_latents = gr.Checkbox(label='Cache latent', value=True)
return ( return (
learning_rate, learning_rate,
@ -459,50 +480,38 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
cache_latents, cache_latents,
) )
def run_cmd_training(**kwargs): def run_cmd_training(**kwargs):
options = [ options = [
f' --learning_rate="{kwargs.get("learning_rate", "")}"' f' --learning_rate="{kwargs.get("learning_rate", "")}"'
if kwargs.get('learning_rate') if kwargs.get('learning_rate')
else '', else '',
f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"' f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
if kwargs.get('lr_scheduler') if kwargs.get('lr_scheduler')
else '', else '',
f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"' f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
if kwargs.get('lr_warmup_steps') if kwargs.get('lr_warmup_steps')
else '', else '',
f' --train_batch_size="{kwargs.get("train_batch_size", "")}"' f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
if kwargs.get('train_batch_size') if kwargs.get('train_batch_size')
else '', else '',
f' --max_train_steps="{kwargs.get("max_train_steps", "")}"' f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
if kwargs.get('max_train_steps') if kwargs.get('max_train_steps')
else '', else '',
f' --save_every_n_epochs="{kwargs.get("save_every_n_epochs", "")}"' f' --save_every_n_epochs="{kwargs.get("save_every_n_epochs", "")}"'
if kwargs.get('save_every_n_epochs') if kwargs.get('save_every_n_epochs')
else '', else '',
f' --mixed_precision="{kwargs.get("mixed_precision", "")}"' f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
if kwargs.get('mixed_precision') if kwargs.get('mixed_precision')
else '', else '',
f' --save_precision="{kwargs.get("save_precision", "")}"' f' --save_precision="{kwargs.get("save_precision", "")}"'
if kwargs.get('save_precision') if kwargs.get('save_precision')
else '', else '',
f' --seed="{kwargs.get("seed", "")}"' if kwargs.get('seed') else '',
f' --seed="{kwargs.get("seed", "")}"'
if kwargs.get('seed')
else '',
f' --caption_extension="{kwargs.get("caption_extension", "")}"' f' --caption_extension="{kwargs.get("caption_extension", "")}"'
if kwargs.get('caption_extension') if kwargs.get('caption_extension')
else '', else '',
' --cache_latents' if kwargs.get('cache_latents') else '', ' --cache_latents' if kwargs.get('cache_latents') else '',
] ]
run_cmd = ''.join(options) run_cmd = ''.join(options)
return run_cmd return run_cmd
@ -532,9 +541,7 @@ def gradio_advanced_training():
gradient_checkpointing = gr.Checkbox( gradient_checkpointing = gr.Checkbox(
label='Gradient checkpointing', value=False label='Gradient checkpointing', value=False
) )
shuffle_caption = gr.Checkbox( shuffle_caption = gr.Checkbox(label='Shuffle caption', value=False)
label='Shuffle caption', value=False
)
persistent_data_loader_workers = gr.Checkbox( persistent_data_loader_workers = gr.Checkbox(
label='Persistent data loader', value=False label='Persistent data loader', value=False
) )
@ -544,10 +551,18 @@ def gradio_advanced_training():
with gr.Row(): with gr.Row():
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True) xformers = gr.Checkbox(label='Use xformers', value=True)
color_aug = gr.Checkbox( color_aug = gr.Checkbox(label='Color augmentation', value=False)
label='Color augmentation', value=False
)
flip_aug = gr.Checkbox(label='Flip augmentation', value=False) flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
with gr.Row():
bucket_no_upscale = gr.Checkbox(
label="Don't upscale bucket resolution", value=True
)
random_crop = gr.Checkbox(
label='Random crop instead of center crop', value=False
)
bucket_reso_steps = gr.Number(
label='Bucket resolution steps', value=64
)
with gr.Row(): with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False) save_state = gr.Checkbox(label='Save training state', value=False)
resume = gr.Textbox( resume = gr.Textbox(
@ -581,55 +596,53 @@ def gradio_advanced_training():
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
) )
def run_cmd_advanced_training(**kwargs): def run_cmd_advanced_training(**kwargs):
options = [ options = [
f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"' f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
if kwargs.get('max_train_epochs') if kwargs.get('max_train_epochs')
else '', else '',
f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"' f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
if kwargs.get('max_data_loader_n_workers') if kwargs.get('max_data_loader_n_workers')
else '', else '',
f' --max_token_length={kwargs.get("max_token_length", "")}' f' --max_token_length={kwargs.get("max_token_length", "")}'
if int(kwargs.get('max_token_length', 75)) > 75 if int(kwargs.get('max_token_length', 75)) > 75
else '', else '',
f' --clip_skip={kwargs.get("clip_skip", "")}' f' --clip_skip={kwargs.get("clip_skip", "")}'
if int(kwargs.get('clip_skip', 1)) > 1 if int(kwargs.get('clip_skip', 1)) > 1
else '', else '',
f' --resume="{kwargs.get("resume", "")}"' f' --resume="{kwargs.get("resume", "")}"'
if kwargs.get('resume') if kwargs.get('resume')
else '', else '',
f' --keep_tokens="{kwargs.get("keep_tokens", "")}"' f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
if int(kwargs.get('keep_tokens', 0)) > 0 if int(kwargs.get('keep_tokens', 0)) > 0
else '', else '',
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
if int(kwargs.get('bucket_reso_steps', 64)) >= 1
else '',
' --save_state' if kwargs.get('save_state') else '', ' --save_state' if kwargs.get('save_state') else '',
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '', ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
' --color_aug' if kwargs.get('color_aug') else '', ' --color_aug' if kwargs.get('color_aug') else '',
' --flip_aug' if kwargs.get('flip_aug') else '', ' --flip_aug' if kwargs.get('flip_aug') else '',
' --shuffle_caption' if kwargs.get('shuffle_caption') else '', ' --shuffle_caption' if kwargs.get('shuffle_caption') else '',
' --gradient_checkpointing'
' --gradient_checkpointing' if kwargs.get('gradient_checkpointing') else '', if kwargs.get('gradient_checkpointing')
else '',
' --full_fp16' if kwargs.get('full_fp16') else '', ' --full_fp16' if kwargs.get('full_fp16') else '',
' --xformers' if kwargs.get('xformers') else '', ' --xformers' if kwargs.get('xformers') else '',
' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '', ' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
' --persistent_data_loader_workers'
' --persistent_data_loader_workers' if kwargs.get('persistent_data_loader_workers') else '', if kwargs.get('persistent_data_loader_workers')
else '',
' --bucket_no_upscale' if kwargs.get('bucket_no_upscale') else '',
' --random_crop' if kwargs.get('random_crop') else '',
] ]
run_cmd = ''.join(options) run_cmd = ''.join(options)
return run_cmd return run_cmd

View File

@ -191,9 +191,7 @@ def gradio_dreambooth_folder_creation_tab(
util_training_dir_output, util_training_dir_output,
], ],
) )
button_copy_info_to_Folders_tab = gr.Button( button_copy_info_to_Folders_tab = gr.Button('Copy info to Folders Tab')
'Copy info to Folders Tab'
)
button_copy_info_to_Folders_tab.click( button_copy_info_to_Folders_tab.click(
copy_info_to_Folders_tab, copy_info_to_Folders_tab,
inputs=[util_training_dir_output], inputs=[util_training_dir_output],

View File

@ -2,7 +2,11 @@ import gradio as gr
from easygui import msgbox from easygui import msgbox
import subprocess import subprocess
import os import os
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
@ -11,13 +15,18 @@ document_symbol = '\U0001F4C4' # 📄
def extract_lora( def extract_lora(
model_tuned, model_org, save_to, save_precision, dim, v2, model_tuned,
model_org,
save_to,
save_precision,
dim,
v2,
): ):
# Check for caption_text_input # Check for caption_text_input
if model_tuned == '': if model_tuned == '':
msgbox('Invalid finetuned model file') msgbox('Invalid finetuned model file')
return return
if model_org == '': if model_org == '':
msgbox('Invalid base model file') msgbox('Invalid base model file')
return return
@ -26,12 +35,14 @@ def extract_lora(
if not os.path.isfile(model_tuned): if not os.path.isfile(model_tuned):
msgbox('The provided finetuned model is not a file') msgbox('The provided finetuned model is not a file')
return return
if not os.path.isfile(model_org): if not os.path.isfile(model_org):
msgbox('The provided base model is not a file') msgbox('The provided base model is not a file')
return return
run_cmd = f'.\\venv\Scripts\python.exe "networks\extract_lora_from_models.py"' run_cmd = (
f'.\\venv\Scripts\python.exe "networks\extract_lora_from_models.py"'
)
run_cmd += f' --save_precision {save_precision}' run_cmd += f' --save_precision {save_precision}'
run_cmd += f' --save_to "{save_to}"' run_cmd += f' --save_to "{save_to}"'
run_cmd += f' --model_org "{model_org}"' run_cmd += f' --model_org "{model_org}"'
@ -60,7 +71,7 @@ def gradio_extract_lora_tab():
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False) model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False)
model_ext_name = gr.Textbox(value='Model types', visible=False) model_ext_name = gr.Textbox(value='Model types', visible=False)
with gr.Row(): with gr.Row():
model_tuned = gr.Textbox( model_tuned = gr.Textbox(
label='Finetuned model', label='Finetuned model',
@ -75,7 +86,7 @@ def gradio_extract_lora_tab():
inputs=[model_tuned, model_ext, model_ext_name], inputs=[model_tuned, model_ext, model_ext_name],
outputs=model_tuned, outputs=model_tuned,
) )
model_org = gr.Textbox( model_org = gr.Textbox(
label='Stable Diffusion base model', label='Stable Diffusion base model',
placeholder='Stable Diffusion original model: ckpt or safetensors file', placeholder='Stable Diffusion original model: ckpt or safetensors file',
@ -99,7 +110,9 @@ def gradio_extract_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_save_to.click( button_save_to.click(
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
) )
save_precision = gr.Dropdown( save_precision = gr.Dropdown(
label='Save precison', label='Save precison',
@ -122,6 +135,5 @@ def gradio_extract_lora_tab():
extract_button.click( extract_button.click(
extract_lora, extract_lora,
inputs=[model_tuned, model_org, save_to, save_precision, dim, v2 inputs=[model_tuned, model_org, save_to, save_precision, dim, v2],
],
) )

View File

@ -15,11 +15,11 @@ def caption_images(
prefix, prefix,
postfix, postfix,
): ):
# Check for images_dir_input # Check for images_dir_input
if train_data_dir == '': if train_data_dir == '':
msgbox('Image folder is missing...') msgbox('Image folder is missing...')
return return
if caption_ext == '': if caption_ext == '':
msgbox('Please provide an extension for the caption files.') msgbox('Please provide an extension for the caption files.')
return return
@ -29,7 +29,9 @@ def caption_images(
if not model_id == '': if not model_id == '':
run_cmd += f' --model_id="{model_id}"' run_cmd += f' --model_id="{model_id}"'
run_cmd += f' --batch_size="{int(batch_size)}"' run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"' run_cmd += (
f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"'
)
run_cmd += f' --max_length="{int(max_length)}"' run_cmd += f' --max_length="{int(max_length)}"'
if caption_ext != '': if caption_ext != '':
run_cmd += f' --caption_extension="{caption_ext}"' run_cmd += f' --caption_extension="{caption_ext}"'
@ -105,8 +107,9 @@ def gradio_git_caption_gui_tab():
value=75, label='Max length', interactive=True value=75, label='Max length', interactive=True
) )
model_id = gr.Textbox( model_id = gr.Textbox(
label="Model", label='Model',
placeholder="(Optional) model id for GIT in Hugging Face", interactive=True placeholder='(Optional) model id for GIT in Hugging Face',
interactive=True,
) )
caption_button = gr.Button('Caption images') caption_button = gr.Button('Caption images')

View File

@ -2,7 +2,11 @@ import gradio as gr
from easygui import msgbox from easygui import msgbox
import subprocess import subprocess
import os import os
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
@ -11,13 +15,18 @@ document_symbol = '\U0001F4C4' # 📄
def merge_lora( def merge_lora(
lora_a_model, lora_b_model, ratio, save_to, precision, save_precision, lora_a_model,
lora_b_model,
ratio,
save_to,
precision,
save_precision,
): ):
# Check for caption_text_input # Check for caption_text_input
if lora_a_model == '': if lora_a_model == '':
msgbox('Invalid model A file') msgbox('Invalid model A file')
return return
if lora_b_model == '': if lora_b_model == '':
msgbox('Invalid model B file') msgbox('Invalid model B file')
return return
@ -26,7 +35,7 @@ def merge_lora(
if not os.path.isfile(lora_a_model): if not os.path.isfile(lora_a_model):
msgbox('The provided model A is not a file') msgbox('The provided model A is not a file')
return return
if not os.path.isfile(lora_b_model): if not os.path.isfile(lora_b_model):
msgbox('The provided model B is not a file') msgbox('The provided model B is not a file')
return return
@ -54,13 +63,11 @@ def merge_lora(
def gradio_merge_lora_tab(): def gradio_merge_lora_tab():
with gr.Tab('Merge LoRA'): with gr.Tab('Merge LoRA'):
gr.Markdown( gr.Markdown('This utility can merge two LoRA networks together.')
'This utility can merge two LoRA networks together.'
)
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
with gr.Row(): with gr.Row():
lora_a_model = gr.Textbox( lora_a_model = gr.Textbox(
label='LoRA model "A"', label='LoRA model "A"',
@ -75,7 +82,7 @@ def gradio_merge_lora_tab():
inputs=[lora_a_model, lora_ext, lora_ext_name], inputs=[lora_a_model, lora_ext, lora_ext_name],
outputs=lora_a_model, outputs=lora_a_model,
) )
lora_b_model = gr.Textbox( lora_b_model = gr.Textbox(
label='LoRA model "B"', label='LoRA model "B"',
placeholder='Path to the LoRA B model', placeholder='Path to the LoRA B model',
@ -90,9 +97,15 @@ def gradio_merge_lora_tab():
outputs=lora_b_model, outputs=lora_b_model,
) )
with gr.Row(): with gr.Row():
ratio = gr.Slider(label="Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B", minimum=0, maximum=1, step=0.01, value=0.5, ratio = gr.Slider(
interactive=True,) label='Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B',
minimum=0,
maximum=1,
step=0.01,
value=0.5,
interactive=True,
)
with gr.Row(): with gr.Row():
save_to = gr.Textbox( save_to = gr.Textbox(
label='Save to', label='Save to',
@ -103,7 +116,9 @@ def gradio_merge_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_save_to.click( button_save_to.click(
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
) )
precision = gr.Dropdown( precision = gr.Dropdown(
label='Merge precison', label='Merge precison',
@ -122,6 +137,12 @@ def gradio_merge_lora_tab():
convert_button.click( convert_button.click(
merge_lora, merge_lora,
inputs=[lora_a_model, lora_b_model, ratio, save_to, precision, save_precision, inputs=[
lora_a_model,
lora_b_model,
ratio,
save_to,
precision,
save_precision,
], ],
) )

View File

@ -11,7 +11,11 @@ document_symbol = '\U0001F4C4' # 📄
def resize_lora( def resize_lora(
model, new_rank, save_to, save_precision, device, model,
new_rank,
save_to,
save_precision,
device,
): ):
# Check for caption_text_input # Check for caption_text_input
if model == '': if model == '':
@ -22,7 +26,7 @@ def resize_lora(
if not os.path.isfile(model): if not os.path.isfile(model):
msgbox('The provided model is not a file') msgbox('The provided model is not a file')
return return
if device == '': if device == '':
device = 'cuda' device = 'cuda'
@ -46,13 +50,11 @@ def resize_lora(
def gradio_resize_lora_tab(): def gradio_resize_lora_tab():
with gr.Tab('Resize LoRA'): with gr.Tab('Resize LoRA'):
gr.Markdown( gr.Markdown('This utility can resize a LoRA.')
'This utility can resize a LoRA.'
)
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
with gr.Row(): with gr.Row():
model = gr.Textbox( model = gr.Textbox(
label='Source LoRA', label='Source LoRA',
@ -68,9 +70,15 @@ def gradio_resize_lora_tab():
outputs=model, outputs=model,
) )
with gr.Row(): with gr.Row():
new_rank = gr.Slider(label="Desired LoRA rank", minimum=1, maximum=1024, step=1, value=4, new_rank = gr.Slider(
interactive=True,) label='Desired LoRA rank',
minimum=1,
maximum=1024,
step=1,
value=4,
interactive=True,
)
with gr.Row(): with gr.Row():
save_to = gr.Textbox( save_to = gr.Textbox(
label='Save to', label='Save to',
@ -81,7 +89,9 @@ def gradio_resize_lora_tab():
folder_symbol, elem_id='open_folder_small' folder_symbol, elem_id='open_folder_small'
) )
button_save_to.click( button_save_to.click(
get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to get_saveasfilename_path,
inputs=[save_to, lora_ext, lora_ext_name],
outputs=save_to,
) )
save_precision = gr.Dropdown( save_precision = gr.Dropdown(
label='Save precison', label='Save precison',
@ -99,6 +109,11 @@ def gradio_resize_lora_tab():
convert_button.click( convert_button.click(
resize_lora, resize_lora,
inputs=[model, new_rank, save_to, save_precision, device, inputs=[
model,
new_rank,
save_to,
save_precision,
device,
], ],
) )

View File

@ -2,7 +2,11 @@ import gradio as gr
from easygui import msgbox from easygui import msgbox
import subprocess import subprocess
import os import os
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
)
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
@ -30,9 +34,11 @@ def verify_lora(
# Run the command # Run the command
subprocess.run(run_cmd) subprocess.run(run_cmd)
process = subprocess.Popen(run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) process = subprocess.Popen(
run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
output, error = process.communicate() output, error = process.communicate()
return (output.decode(), error.decode()) return (output.decode(), error.decode())
@ -46,10 +52,10 @@ def gradio_verify_lora_tab():
gr.Markdown( gr.Markdown(
'This utility can verify a LoRA network to make sure it is properly trained.' 'This utility can verify a LoRA network to make sure it is properly trained.'
) )
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
with gr.Row(): with gr.Row():
lora_model = gr.Textbox( lora_model = gr.Textbox(
label='LoRA model', label='LoRA model',
@ -64,7 +70,7 @@ def gradio_verify_lora_tab():
inputs=[lora_model, lora_ext, lora_ext_name], inputs=[lora_model, lora_ext, lora_ext_name],
outputs=lora_model, outputs=lora_model,
) )
verify_button = gr.Button('Verify', variant="primary") verify_button = gr.Button('Verify', variant='primary')
lora_model_verif_output = gr.Textbox( lora_model_verif_output = gr.Textbox(
label='Output', label='Output',
@ -73,7 +79,7 @@ def gradio_verify_lora_tab():
lines=1, lines=1,
max_lines=10, max_lines=10,
) )
lora_model_verif_error = gr.Textbox( lora_model_verif_error = gr.Textbox(
label='Error', label='Error',
placeholder='Verification error', placeholder='Verification error',
@ -87,5 +93,5 @@ def gradio_verify_lora_tab():
inputs=[ inputs=[
lora_model, lora_model,
], ],
outputs=[lora_model_verif_output, lora_model_verif_error] outputs=[lora_model_verif_output, lora_model_verif_error],
) )

View File

@ -14,7 +14,7 @@ def caption_images(train_data_dir, caption_extension, batch_size, thresh):
if train_data_dir == '': if train_data_dir == '':
msgbox('Image folder is missing...') msgbox('Image folder is missing...')
return return
if caption_extension == '': if caption_extension == '':
msgbox('Please provide an extension for the caption files.') msgbox('Please provide an extension for the caption files.')
return return

View File

@ -91,9 +91,14 @@ def save_configuration(
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha, network_alpha,
training_comment, keep_tokens, training_comment,
lr_scheduler_num_cycles, lr_scheduler_power, keep_tokens,
lr_scheduler_num_cycles,
lr_scheduler_power,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -182,9 +187,14 @@ def open_configuration(
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha, network_alpha,
training_comment, keep_tokens, training_comment,
lr_scheduler_num_cycles, lr_scheduler_power, keep_tokens,
lr_scheduler_num_cycles,
lr_scheduler_power,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -257,9 +267,14 @@ def train_model(
max_train_epochs, max_train_epochs,
max_data_loader_n_workers, max_data_loader_n_workers,
network_alpha, network_alpha,
training_comment, keep_tokens, training_comment,
lr_scheduler_num_cycles, lr_scheduler_power, keep_tokens,
lr_scheduler_num_cycles,
lr_scheduler_power,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -281,12 +296,18 @@ def train_model(
if output_dir == '': if output_dir == '':
msgbox('Output folder path is missing') msgbox('Output folder path is missing')
return return
if int(bucket_reso_steps) < 1:
msgbox('Bucket resolution steps need to be greater than 0')
return
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
if stop_text_encoder_training_pct > 0: if stop_text_encoder_training_pct > 0:
msgbox('Output "stop text encoder training" is not yet supported. Ignoring') msgbox(
'Output "stop text encoder training" is not yet supported. Ignoring'
)
stop_text_encoder_training_pct = 0 stop_text_encoder_training_pct = 0
# If string is empty set string to 0. # If string is empty set string to 0.
@ -358,9 +379,9 @@ def train_model(
print(f'lr_warmup_steps = {lr_warmup_steps}') print(f'lr_warmup_steps = {lr_warmup_steps}')
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"' run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'
run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop' run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop'
if v2: if v2:
run_cmd += ' --v2' run_cmd += ' --v2'
if v_parameterization: if v_parameterization:
@ -390,7 +411,7 @@ def train_model(
if not float(prior_loss_weight) == 1.0: if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}' run_cmd += f' --prior_loss_weight={prior_loss_weight}'
run_cmd += f' --network_module=networks.lora' run_cmd += f' --network_module=networks.lora'
if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0): if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0): if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
run_cmd += f' --text_encoder_lr={text_encoder_lr}' run_cmd += f' --text_encoder_lr={text_encoder_lr}'
@ -402,14 +423,12 @@ def train_model(
run_cmd += f' --unet_lr={unet_lr}' run_cmd += f' --unet_lr={unet_lr}'
run_cmd += f' --network_train_unet_only' run_cmd += f' --network_train_unet_only'
else: else:
if float(text_encoder_lr) == 0: if float(text_encoder_lr) == 0:
msgbox( msgbox('Please input learning rate values.')
'Please input learning rate values.'
)
return return
run_cmd += f' --network_dim={network_dim}' run_cmd += f' --network_dim={network_dim}'
if not lora_network_weights == '': if not lora_network_weights == '':
run_cmd += f' --network_weights="{lora_network_weights}"' run_cmd += f' --network_weights="{lora_network_weights}"'
if int(gradient_accumulation_steps) > 1: if int(gradient_accumulation_steps) > 1:
@ -454,6 +473,9 @@ def train_model(
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens, keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers, persistent_data_loader_workers=persistent_data_loader_workers,
bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps,
) )
print(run_cmd) print(run_cmd)
@ -675,11 +697,13 @@ def lora_tab(
label='Prior loss weight', value=1.0 label='Prior loss weight', value=1.0
) )
lr_scheduler_num_cycles = gr.Textbox( lr_scheduler_num_cycles = gr.Textbox(
label='LR number of cycles', placeholder='(Optional) For Cosine with restart and polynomial only' label='LR number of cycles',
placeholder='(Optional) For Cosine with restart and polynomial only',
) )
lr_scheduler_power = gr.Textbox( lr_scheduler_power = gr.Textbox(
label='LR power', placeholder='(Optional) For Cosine with restart and polynomial only' label='LR power',
placeholder='(Optional) For Cosine with restart and polynomial only',
) )
( (
use_8bit_adam, use_8bit_adam,
@ -698,6 +722,9 @@ def lora_tab(
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -719,7 +746,6 @@ def lora_tab(
gradio_merge_lora_tab() gradio_merge_lora_tab()
gradio_resize_lora_tab() gradio_resize_lora_tab()
gradio_verify_lora_tab() gradio_verify_lora_tab()
button_run = gr.Button('Train model') button_run = gr.Button('Train model')
@ -773,8 +799,12 @@ def lora_tab(
network_alpha, network_alpha,
training_comment, training_comment,
keep_tokens, keep_tokens,
lr_scheduler_num_cycles, lr_scheduler_power, lr_scheduler_num_cycles,
lr_scheduler_power,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
] ]
button_open_config.click( button_open_config.click(

View File

@ -82,8 +82,18 @@ def save_configuration(
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, model_list,
token_string,
init_word,
num_vectors_per_token,
max_train_steps,
weights,
template,
keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -171,8 +181,18 @@ def open_configuration(
max_data_loader_n_workers, max_data_loader_n_workers,
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, model_list,
token_string,
init_word,
num_vectors_per_token,
max_train_steps,
weights,
template,
keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -241,8 +261,17 @@ def train_model(
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, token_string,
init_word,
num_vectors_per_token,
max_train_steps,
weights,
template,
keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -264,15 +293,15 @@ def train_model(
if output_dir == '': if output_dir == '':
msgbox('Output folder path is missing') msgbox('Output folder path is missing')
return return
if token_string == '': if token_string == '':
msgbox('Token string is missing') msgbox('Token string is missing')
return return
if init_word == '': if init_word == '':
msgbox('Init word is missing') msgbox('Init word is missing')
return return
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
@ -332,7 +361,7 @@ def train_model(
) )
else: else:
max_train_steps = int(max_train_steps) max_train_steps = int(max_train_steps)
print(f'max_train_steps = {max_train_steps}') print(f'max_train_steps = {max_train_steps}')
# calculate stop encoder training # calculate stop encoder training
@ -421,6 +450,9 @@ def train_model(
use_8bit_adam=use_8bit_adam, use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens, keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers, persistent_data_loader_workers=persistent_data_loader_workers,
bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps,
) )
run_cmd += f' --token_string="{token_string}"' run_cmd += f' --token_string="{token_string}"'
run_cmd += f' --init_word="{init_word}"' run_cmd += f' --init_word="{init_word}"'
@ -431,7 +463,7 @@ def train_model(
run_cmd += f' --use_object_template' run_cmd += f' --use_object_template'
elif template == 'style template': elif template == 'style template':
run_cmd += f' --use_style_template' run_cmd += f' --use_style_template'
print(run_cmd) print(run_cmd)
# Run the command # Run the command
subprocess.run(run_cmd) subprocess.run(run_cmd)
@ -576,9 +608,7 @@ def ti_tab(
label='Resume TI training', label='Resume TI training',
placeholder='(Optional) Path to existing TI embeding file to keep training', placeholder='(Optional) Path to existing TI embeding file to keep training',
) )
weights_file_input = gr.Button( weights_file_input = gr.Button('📂', elem_id='open_folder_small')
'📂', elem_id='open_folder_small'
)
weights_file_input.click(get_file_path, outputs=weights) weights_file_input.click(get_file_path, outputs=weights)
with gr.Row(): with gr.Row():
token_string = gr.Textbox( token_string = gr.Textbox(
@ -676,6 +706,9 @@ def ti_tab(
max_data_loader_n_workers, max_data_loader_n_workers,
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -739,9 +772,17 @@ def ti_tab(
mem_eff_attn, mem_eff_attn,
gradient_accumulation_steps, gradient_accumulation_steps,
model_list, model_list,
token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, token_string,
init_word,
num_vectors_per_token,
max_train_steps,
weights,
template,
keep_tokens, keep_tokens,
persistent_data_loader_workers, persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps,
] ]
button_open_config.click( button_open_config.click(

View File

@ -1,66 +0,0 @@
import os
import cv2
import argparse
import shutil
import math
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Iterate through all files in src_img_folder
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp
if not filename.endswith(('.png', '.jpg', '.webp')):
# Copy the file to the destination folder if not png, jpg or webp
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue
# Load image
img = cv2.imread(os.path.join(src_img_folder, filename))
# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]
# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels
# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height))
# Calculate the new height and width that are divisible by divisible_by
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]
# Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]}")
def main():
parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
parser.add_argument('--max_resolution', type=str, help='Maximum resolution in the format "512x512"', default="512x512")
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=2)
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,76 @@
import os
import cv2
import argparse
import shutil
import math
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
# Split the max_resolution string by "," and strip any whitespaces
max_resolutions = [res.strip() for res in max_resolution.split(',')]
# # Calculate max_pixels from max_resolution string
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Iterate through all files in src_img_folder
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp
if not filename.endswith(('.png', '.jpg', '.webp')):
# Copy the file to the destination folder if not png, jpg or webp
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue
# Load image
img = cv2.imread(os.path.join(src_img_folder, filename))
for max_resolution in max_resolutions:
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]
# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels
# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height))
# Calculate the new height and width that are divisible by divisible_by
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]
# Split filename into base and extension
base, ext = os.path.splitext(filename)
new_filename = base + '+' + max_resolution + '.jpg'
# Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
def main():
parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution(s)')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
parser.add_argument('--max_resolution', type=str, help='Maximum resolution(s) in the format "512x512,384x384, etc, etc"', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=1)
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution)
if __name__ == '__main__':
main()