diff --git a/dreambooth_gui.py b/dreambooth_gui.py
index f64480b..cdcb85b 100644
--- a/dreambooth_gui.py
+++ b/dreambooth_gui.py
@@ -82,8 +82,12 @@ def save_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list, keep_tokens,
+    model_list,
+    keep_tokens,
     persistent_data_loader_workers,
+    bucket_no_upscale,
+    random_crop,
+    bucket_reso_steps,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -167,8 +171,12 @@ def open_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list, keep_tokens,
+    model_list,
+    keep_tokens,
     persistent_data_loader_workers,
+    bucket_no_upscale,
+    random_crop,
+    bucket_reso_steps,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -239,6 +247,9 @@ def train_model(
     model_list,  # Keep this. Yes, it is unused here but required given the common list used
     keep_tokens,
     persistent_data_loader_workers,
+    bucket_no_upscale,
+    random_crop,
+    bucket_reso_steps,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -402,6 +413,9 @@ def train_model(
         use_8bit_adam=use_8bit_adam,
         keep_tokens=keep_tokens,
         persistent_data_loader_workers=persistent_data_loader_workers,
+        bucket_no_upscale=bucket_no_upscale,
+        random_crop=random_crop,
+        bucket_reso_steps=bucket_reso_steps,
     )

     print(run_cmd)
@@ -610,6 +624,9 @@ def dreambooth_tab(
         max_data_loader_n_workers,
         keep_tokens,
         persistent_data_loader_workers,
+        bucket_no_upscale,
+        random_crop,
+        bucket_reso_steps,
     ) = gradio_advanced_training()
     color_aug.change(
         color_aug_changed,
@@ -675,6 +692,9 @@ def dreambooth_tab(
         model_list,
         keep_tokens,
         persistent_data_loader_workers,
+        bucket_no_upscale,
+        random_crop,
+        bucket_reso_steps,
     ]

     button_open_config.click(
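A note on the pattern this change touches in every GUI file: save_configuration and open_configuration capture their own arguments via parameters = list(locals().items()), and each tab passes one shared, ordered settings list, so the three new names (bucket_no_upscale, random_crop, bucket_reso_steps) must be appended at the same position in every signature and every settings_list; this is also why the unused model_list parameter has to stay. A toy illustration of that contract (hypothetical names, not repository code):

    # Toy version of the save_configuration pattern: the function reads its own
    # arguments through locals(), so signature order must mirror the caller's list.
    def capture_settings(learning_rate, bucket_no_upscale, random_crop, bucket_reso_steps):
        parameters = list(locals().items())
        return dict(parameters)

    settings_list = [1e-6, True, False, 64]  # order must match the signature
    print(capture_settings(*settings_list))
    # -> {'learning_rate': 1e-06, 'bucket_no_upscale': True, 'random_crop': False, 'bucket_reso_steps': 64}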
diff --git a/finetune_gui.py b/finetune_gui.py
index 49dcd52..80c887f 100644
--- a/finetune_gui.py
+++ b/finetune_gui.py
@@ -78,8 +82,12 @@ def save_configuration(
     color_aug,
     model_list,
     cache_latents,
-    use_latent_files, keep_tokens,
+    use_latent_files,
+    keep_tokens,
     persistent_data_loader_workers,
+    bucket_no_upscale,
+    random_crop,
+    bucket_reso_steps,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -169,8 +173,12 @@ def open_config_file(
     color_aug,
     model_list,
     cache_latents,
-    use_latent_files, keep_tokens,
+    use_latent_files,
+    keep_tokens,
     persistent_data_loader_workers,
+    bucket_no_upscale,
+    random_crop,
+    bucket_reso_steps,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -245,8 +253,12 @@ def train_model(
     color_aug,
     model_list,  # Keep this. Yes, it is unused here but required given the common list used
     cache_latents,
-    use_latent_files, keep_tokens,
+    use_latent_files,
+    keep_tokens,
     persistent_data_loader_workers,
+    bucket_no_upscale,
+    random_crop,
+    bucket_reso_steps,
 ):
     # create caption json file
     if generate_caption_database:
@@ -295,7 +307,11 @@ def train_model(
         subprocess.run(run_cmd)

     image_num = len(
-        [f for f in os.listdir(image_folder) if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')]
+        [
+            f
+            for f in os.listdir(image_folder)
+            if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp')
+        ]
     )
     print(f'image_num = {image_num}')
@@ -386,6 +402,9 @@ def train_model(
         use_8bit_adam=use_8bit_adam,
         keep_tokens=keep_tokens,
         persistent_data_loader_workers=persistent_data_loader_workers,
+        bucket_no_upscale=bucket_no_upscale,
+        random_crop=random_crop,
+        bucket_reso_steps=bucket_reso_steps,
     )

     print(run_cmd)
@@ -592,6 +611,9 @@ def finetune_tab():
         max_data_loader_n_workers,
         keep_tokens,
         persistent_data_loader_workers,
+        bucket_no_upscale,
+        random_crop,
+        bucket_reso_steps,
     ) = gradio_advanced_training()
     color_aug.change(
         color_aug_changed,
@@ -653,6 +675,9 @@ def finetune_tab():
         use_latent_files,
         keep_tokens,
         persistent_data_loader_workers,
+        bucket_no_upscale,
+        random_crop,
+        bucket_reso_steps,
     ]

     button_run.click(train_model, inputs=settings_list)
diff --git a/kohya_gui.py b/kohya_gui.py
index 1031810..fa51fd6 100644
--- a/kohya_gui.py
+++ b/kohya_gui.py
@@ -19,7 +19,7 @@ def UI(username, password):
             print('Load CSS...')
             css += file.read() + '\n'

-    interface = gr.Blocks(css=css, title="Kohya_ss GUI")
+    interface = gr.Blocks(css=css, title='Kohya_ss GUI')

     with interface:
         with gr.Tab('Dreambooth'):
diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py
index 57ff558..2412dfb 100644
--- a/library/basic_caption_gui.py
+++ b/library/basic_caption_gui.py
@@ -10,13 +10,15 @@ def caption_images(
     overwrite_input,
     caption_file_ext,
     prefix,
-    postfix, find, replace
+    postfix,
+    find,
+    replace,
 ):
     # Check for images_dir_input
     if images_dir_input == '':
         msgbox('Image folder is missing...')
         return
-    
+
     if caption_file_ext == '':
         msgbox('Please provide an extension for the caption files.')
         return
@@ -39,7 +41,7 @@ def caption_images(
     subprocess.run(run_cmd)

     if overwrite_input:
-        if not prefix=='' or not postfix=='':
+        if not prefix == '' or not postfix == '':
             # Add prefix and postfix
             add_pre_postfix(
                 folder=images_dir_input,
@@ -47,7 +49,7 @@ def caption_images(
                 prefix=prefix,
                 postfix=postfix,
             )
-        if not find=='':
+        if not find == '':
             find_replace(
                 folder=images_dir_input,
                 caption_file_ext=caption_file_ext,
@@ -134,6 +136,7 @@ def gradio_basic_caption_gui_tab():
                 caption_file_ext,
                 prefix,
                 postfix,
-                find, replace
+                find,
+                replace,
             ],
         )
diff --git a/library/blip_caption_gui.py b/library/blip_caption_gui.py
index b9ae8ae..61acd75 100644
--- a/library/blip_caption_gui.py
+++ b/library/blip_caption_gui.py
@@ -26,7 +26,7 @@ def caption_images(
     if train_data_dir == '':
         msgbox('Image folder is missing...')
         return
-    
+
     if caption_file_ext == '':
         msgbox('Please provide an extension for the caption files.')
         return
diff --git a/library/common_gui.py b/library/common_gui.py
index 4634ce2..69104de 100644
--- a/library/common_gui.py
+++ b/library/common_gui.py
@@ -9,6 +9,7 @@ refresh_symbol = '\U0001f504'   # 🔄
 save_style_symbol = '\U0001f4be'   # 💾
 document_symbol = '\U0001F4C4'   # 📄

+
 def get_dir_and_file(file_path):
     dir_path, file_name = os.path.split(file_path)
     return (dir_path, file_name)
@@ -200,7 +201,7 @@
 def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
     files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
     for file in files:
-        with open(os.path.join(folder, file), 'r', errors="ignore") as f:
+        with open(os.path.join(folder, file), 'r', errors='ignore') as f:
             content = f.read()
             f.close
             content = content.replace(find, replace)
@@ -304,7 +305,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
 ###
 ### Gradio common GUI section
 ###
-    
+
+
 def gradio_config():
     with gr.Accordion('Configuration file', open=False):
         with gr.Row():
@@ -318,7 +320,13 @@ def gradio_config():
             placeholder="type the configuration file path or use the 'Open' button above to select it...",
             interactive=True,
         )
-    return (button_open_config, button_save_config, button_save_as_config, config_file_name)
+    return (
+        button_open_config,
+        button_save_config,
+        button_save_as_config,
+        config_file_name,
+    )
+

 def gradio_source_model():
     with gr.Tab('Source model'):
@@ -382,9 +390,20 @@ def gradio_source_model():
                 v_parameterization,
             ],
         )
-    return (pretrained_model_name_or_path, v2, v_parameterization, save_model_as, model_list)
+    return (
+        pretrained_model_name_or_path,
+        v2,
+        v_parameterization,
+        save_model_as,
+        model_list,
+    )

-def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', lr_warmup_value='0'):
+
+def gradio_training(
+    learning_rate_value='1e-6',
+    lr_scheduler_value='constant',
+    lr_warmup_value='0',
+):
     with gr.Row():
         train_batch_size = gr.Slider(
             minimum=1,
@@ -394,9 +413,7 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
             step=1,
         )
         epoch = gr.Textbox(label='Epoch', value=1)
-        save_every_n_epochs = gr.Textbox(
-            label='Save every N epochs', value=1
-        )
+        save_every_n_epochs = gr.Textbox(label='Save every N epochs', value=1)
         caption_extension = gr.Textbox(
             label='Caption Extension',
             placeholder='(Optional) Extension for caption files. default: .caption',
@@ -429,7 +446,9 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
         )
         seed = gr.Textbox(label='Seed', value=1234)
     with gr.Row():
-        learning_rate = gr.Textbox(label='Learning rate', value=learning_rate_value)
+        learning_rate = gr.Textbox(
+            label='Learning rate', value=learning_rate_value
+        )
         lr_scheduler = gr.Dropdown(
             label='LR Scheduler',
             choices=[
@@ -442,7 +461,9 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
             ],
             value=lr_scheduler_value,
         )
-        lr_warmup = gr.Textbox(label='LR warmup (% of steps)', value=lr_warmup_value)
+        lr_warmup = gr.Textbox(
+            label='LR warmup (% of steps)', value=lr_warmup_value
+        )
         cache_latents = gr.Checkbox(label='Cache latent', value=True)
     return (
         learning_rate,
@@ -459,50 +480,38 @@ def gradio_training(learning_rate_value='1e-6', lr_scheduler_value='constant', l
         cache_latents,
     )

+
 def run_cmd_training(**kwargs):
     options = [
         f' --learning_rate="{kwargs.get("learning_rate", "")}"'
         if kwargs.get('learning_rate')
         else '',
-
         f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
         if kwargs.get('lr_scheduler')
         else '',
-
         f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
         if kwargs.get('lr_warmup_steps')
         else '',
-
         f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
         if kwargs.get('train_batch_size')
         else '',
-
         f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
         if kwargs.get('max_train_steps')
         else '',
-
         f' --save_every_n_epochs="{kwargs.get("save_every_n_epochs", "")}"'
         if kwargs.get('save_every_n_epochs')
         else '',
-
         f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
         if kwargs.get('mixed_precision')
         else '',
-
         f' --save_precision="{kwargs.get("save_precision", "")}"'
         if kwargs.get('save_precision')
         else '',
-
-        f' --seed="{kwargs.get("seed", "")}"'
-        if kwargs.get('seed')
-        else '',
-
+        f' --seed="{kwargs.get("seed", "")}"' if kwargs.get('seed') else '',
         f' --caption_extension="{kwargs.get("caption_extension", "")}"'
         if kwargs.get('caption_extension')
         else '',
-
         ' --cache_latents' if kwargs.get('cache_latents') else '',
-
     ]
     run_cmd = ''.join(options)
     return run_cmd
@@ -532,9 +541,7 @@ def gradio_advanced_training():
         gradient_checkpointing = gr.Checkbox(
             label='Gradient checkpointing', value=False
         )
-        shuffle_caption = gr.Checkbox(
-            label='Shuffle caption', value=False
-        )
+        shuffle_caption = gr.Checkbox(label='Shuffle caption', value=False)
         persistent_data_loader_workers = gr.Checkbox(
             label='Persistent data loader', value=False
         )
@@ -544,10 +551,18 @@ def gradio_advanced_training():
     with gr.Row():
         use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
         xformers = gr.Checkbox(label='Use xformers', value=True)
-        color_aug = gr.Checkbox(
-            label='Color augmentation', value=False
-        )
+        color_aug = gr.Checkbox(label='Color augmentation', value=False)
         flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
+    with gr.Row():
+        bucket_no_upscale = gr.Checkbox(
+            label="Don't upscale bucket resolution", value=True
+        )
+        random_crop = gr.Checkbox(
+            label='Random crop instead of center crop', value=False
+        )
+        bucket_reso_steps = gr.Number(
+            label='Bucket resolution steps', value=64
+        )
     with gr.Row():
         save_state = gr.Checkbox(label='Save training state', value=False)
         resume = gr.Textbox(
@@ -581,55 +596,53 @@ def gradio_advanced_training():
         max_data_loader_n_workers,
         keep_tokens,
         persistent_data_loader_workers,
+        bucket_no_upscale,
+        random_crop,
+        bucket_reso_steps,
     )

+
 def run_cmd_advanced_training(**kwargs):
     options = [
         f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
         if kwargs.get('max_train_epochs')
         else '',
-
         f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
         if kwargs.get('max_data_loader_n_workers')
         else '',
-
         f' --max_token_length={kwargs.get("max_token_length", "")}'
         if int(kwargs.get('max_token_length', 75)) > 75
         else '',
-
         f' --clip_skip={kwargs.get("clip_skip", "")}'
         if int(kwargs.get('clip_skip', 1)) > 1
         else '',
-
         f' --resume="{kwargs.get("resume", "")}"'
         if kwargs.get('resume')
         else '',
-
         f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
         if int(kwargs.get('keep_tokens', 0)) > 0
         else '',
+        f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
+        if int(kwargs.get('bucket_reso_steps', 64)) >= 1
+        else '',
+
         ' --save_state' if kwargs.get('save_state') else '',
-
         ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
-
         ' --color_aug' if kwargs.get('color_aug') else '',
-
         ' --flip_aug' if kwargs.get('flip_aug') else '',
-
         ' --shuffle_caption' if kwargs.get('shuffle_caption') else '',
-
-        ' --gradient_checkpointing' if kwargs.get('gradient_checkpointing') else '',
-
+        ' --gradient_checkpointing'
+        if kwargs.get('gradient_checkpointing')
+        else '',
         ' --full_fp16' if kwargs.get('full_fp16') else '',
-
         ' --xformers' if kwargs.get('xformers') else '',
-
         ' --use_8bit_adam' if kwargs.get('use_8bit_adam') else '',
-
-        ' --persistent_data_loader_workers' if kwargs.get('persistent_data_loader_workers') else '',
-
+        ' --persistent_data_loader_workers'
+        if kwargs.get('persistent_data_loader_workers')
+        else '',
+        ' --bucket_no_upscale' if kwargs.get('bucket_no_upscale') else '',
+        ' --random_crop' if kwargs.get('random_crop') else '',
     ]
     run_cmd = ''.join(options)
     return run_cmd
-
"networks\extract_lora_from_models.py"' + ) run_cmd += f' --save_precision {save_precision}' run_cmd += f' --save_to "{save_to}"' run_cmd += f' --model_org "{model_org}"' @@ -60,7 +71,7 @@ def gradio_extract_lora_tab(): lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False) model_ext_name = gr.Textbox(value='Model types', visible=False) - + with gr.Row(): model_tuned = gr.Textbox( label='Finetuned model', @@ -75,7 +86,7 @@ def gradio_extract_lora_tab(): inputs=[model_tuned, model_ext, model_ext_name], outputs=model_tuned, ) - + model_org = gr.Textbox( label='Stable Diffusion base model', placeholder='Stable Diffusion original model: ckpt or safetensors file', @@ -99,7 +110,9 @@ def gradio_extract_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to + get_saveasfilename_path, + inputs=[save_to, lora_ext, lora_ext_name], + outputs=save_to, ) save_precision = gr.Dropdown( label='Save precison', @@ -122,6 +135,5 @@ def gradio_extract_lora_tab(): extract_button.click( extract_lora, - inputs=[model_tuned, model_org, save_to, save_precision, dim, v2 - ], + inputs=[model_tuned, model_org, save_to, save_precision, dim, v2], ) diff --git a/library/git_caption_gui.py b/library/git_caption_gui.py index 1863324..c65fb3c 100644 --- a/library/git_caption_gui.py +++ b/library/git_caption_gui.py @@ -15,11 +15,11 @@ def caption_images( prefix, postfix, ): - # Check for images_dir_input + # Check for images_dir_input if train_data_dir == '': msgbox('Image folder is missing...') return - + if caption_ext == '': msgbox('Please provide an extension for the caption files.') return @@ -29,7 +29,9 @@ def caption_images( if not model_id == '': run_cmd += f' --model_id="{model_id}"' run_cmd += f' --batch_size="{int(batch_size)}"' - run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"' + run_cmd += ( + f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"' + ) run_cmd += f' --max_length="{int(max_length)}"' if caption_ext != '': run_cmd += f' --caption_extension="{caption_ext}"' @@ -105,8 +107,9 @@ def gradio_git_caption_gui_tab(): value=75, label='Max length', interactive=True ) model_id = gr.Textbox( - label="Model", - placeholder="(Optional) model id for GIT in Hugging Face", interactive=True + label='Model', + placeholder='(Optional) model id for GIT in Hugging Face', + interactive=True, ) caption_button = gr.Button('Caption images') diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py index 0271963..d51fda2 100644 --- a/library/merge_lora_gui.py +++ b/library/merge_lora_gui.py @@ -2,7 +2,11 @@ import gradio as gr from easygui import msgbox import subprocess import os -from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, +) folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -11,13 +15,18 @@ document_symbol = '\U0001F4C4' # 📄 def merge_lora( - lora_a_model, lora_b_model, ratio, save_to, precision, save_precision, + lora_a_model, + lora_b_model, + ratio, + save_to, + precision, + save_precision, ): # Check for caption_text_input if lora_a_model == '': msgbox('Invalid model A file') return - + if lora_b_model == '': msgbox('Invalid model B file') return @@ -26,7 +35,7 @@ def merge_lora( if not os.path.isfile(lora_a_model): msgbox('The 
provided model A is not a file') return - + if not os.path.isfile(lora_b_model): msgbox('The provided model B is not a file') return @@ -54,13 +63,11 @@ def merge_lora( def gradio_merge_lora_tab(): with gr.Tab('Merge LoRA'): - gr.Markdown( - 'This utility can merge two LoRA networks together.' - ) - + gr.Markdown('This utility can merge two LoRA networks together.') + lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) - + with gr.Row(): lora_a_model = gr.Textbox( label='LoRA model "A"', @@ -75,7 +82,7 @@ def gradio_merge_lora_tab(): inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, ) - + lora_b_model = gr.Textbox( label='LoRA model "B"', placeholder='Path to the LoRA B model', @@ -90,9 +97,15 @@ def gradio_merge_lora_tab(): outputs=lora_b_model, ) with gr.Row(): - ratio = gr.Slider(label="Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B", minimum=0, maximum=1, step=0.01, value=0.5, - interactive=True,) - + ratio = gr.Slider( + label='Merge ratio (eg: 0.7 mean 70% of model A and 30% of model B', + minimum=0, + maximum=1, + step=0.01, + value=0.5, + interactive=True, + ) + with gr.Row(): save_to = gr.Textbox( label='Save to', @@ -103,7 +116,9 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to + get_saveasfilename_path, + inputs=[save_to, lora_ext, lora_ext_name], + outputs=save_to, ) precision = gr.Dropdown( label='Merge precison', @@ -122,6 +137,12 @@ def gradio_merge_lora_tab(): convert_button.click( merge_lora, - inputs=[lora_a_model, lora_b_model, ratio, save_to, precision, save_precision, + inputs=[ + lora_a_model, + lora_b_model, + ratio, + save_to, + precision, + save_precision, ], ) diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index 1ce6ebf..f5750ee 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -11,7 +11,11 @@ document_symbol = '\U0001F4C4' # 📄 def resize_lora( - model, new_rank, save_to, save_precision, device, + model, + new_rank, + save_to, + save_precision, + device, ): # Check for caption_text_input if model == '': @@ -22,7 +26,7 @@ def resize_lora( if not os.path.isfile(model): msgbox('The provided model is not a file') return - + if device == '': device = 'cuda' @@ -46,13 +50,11 @@ def resize_lora( def gradio_resize_lora_tab(): with gr.Tab('Resize LoRA'): - gr.Markdown( - 'This utility can resize a LoRA.' 
- ) - + gr.Markdown('This utility can resize a LoRA.') + lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) - + with gr.Row(): model = gr.Textbox( label='Source LoRA', @@ -68,9 +70,15 @@ def gradio_resize_lora_tab(): outputs=model, ) with gr.Row(): - new_rank = gr.Slider(label="Desired LoRA rank", minimum=1, maximum=1024, step=1, value=4, - interactive=True,) - + new_rank = gr.Slider( + label='Desired LoRA rank', + minimum=1, + maximum=1024, + step=1, + value=4, + interactive=True, + ) + with gr.Row(): save_to = gr.Textbox( label='Save to', @@ -81,7 +89,9 @@ def gradio_resize_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to + get_saveasfilename_path, + inputs=[save_to, lora_ext, lora_ext_name], + outputs=save_to, ) save_precision = gr.Dropdown( label='Save precison', @@ -99,6 +109,11 @@ def gradio_resize_lora_tab(): convert_button.click( resize_lora, - inputs=[model, new_rank, save_to, save_precision, device, + inputs=[ + model, + new_rank, + save_to, + save_precision, + device, ], ) diff --git a/library/verify_lora_gui.py b/library/verify_lora_gui.py index ada20d1..51fa510 100644 --- a/library/verify_lora_gui.py +++ b/library/verify_lora_gui.py @@ -2,7 +2,11 @@ import gradio as gr from easygui import msgbox import subprocess import os -from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, +) folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -30,9 +34,11 @@ def verify_lora( # Run the command subprocess.run(run_cmd) - process = subprocess.Popen(run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + process = subprocess.Popen( + run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) output, error = process.communicate() - + return (output.decode(), error.decode()) @@ -46,10 +52,10 @@ def gradio_verify_lora_tab(): gr.Markdown( 'This utility can verify a LoRA network to make sure it is properly trained.' 
) - + lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) - + with gr.Row(): lora_model = gr.Textbox( label='LoRA model', @@ -64,7 +70,7 @@ def gradio_verify_lora_tab(): inputs=[lora_model, lora_ext, lora_ext_name], outputs=lora_model, ) - verify_button = gr.Button('Verify', variant="primary") + verify_button = gr.Button('Verify', variant='primary') lora_model_verif_output = gr.Textbox( label='Output', @@ -73,7 +79,7 @@ def gradio_verify_lora_tab(): lines=1, max_lines=10, ) - + lora_model_verif_error = gr.Textbox( label='Error', placeholder='Verification error', @@ -87,5 +93,5 @@ def gradio_verify_lora_tab(): inputs=[ lora_model, ], - outputs=[lora_model_verif_output, lora_model_verif_error] + outputs=[lora_model_verif_output, lora_model_verif_error], ) diff --git a/library/wd14_caption_gui.py b/library/wd14_caption_gui.py index d50c32a..3d97ebf 100644 --- a/library/wd14_caption_gui.py +++ b/library/wd14_caption_gui.py @@ -14,7 +14,7 @@ def caption_images(train_data_dir, caption_extension, batch_size, thresh): if train_data_dir == '': msgbox('Image folder is missing...') return - + if caption_extension == '': msgbox('Please provide an extension for the caption files.') return diff --git a/lora_gui.py b/lora_gui.py index 5639b2f..04c0cc1 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -91,9 +91,14 @@ def save_configuration( max_train_epochs, max_data_loader_n_workers, network_alpha, - training_comment, keep_tokens, - lr_scheduler_num_cycles, lr_scheduler_power, + training_comment, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -182,9 +187,14 @@ def open_configuration( max_train_epochs, max_data_loader_n_workers, network_alpha, - training_comment, keep_tokens, - lr_scheduler_num_cycles, lr_scheduler_power, + training_comment, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -257,9 +267,14 @@ def train_model( max_train_epochs, max_data_loader_n_workers, network_alpha, - training_comment, keep_tokens, - lr_scheduler_num_cycles, lr_scheduler_power, + training_comment, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -281,12 +296,18 @@ def train_model( if output_dir == '': msgbox('Output folder path is missing') return - + + if int(bucket_reso_steps) < 1: + msgbox('Bucket resolution steps need to be greater than 0') + return + if not os.path.exists(output_dir): os.makedirs(output_dir) - + if stop_text_encoder_training_pct > 0: - msgbox('Output "stop text encoder training" is not yet supported. Ignoring') + msgbox( + 'Output "stop text encoder training" is not yet supported. Ignoring' + ) stop_text_encoder_training_pct = 0 # If string is empty set string to 0. 
@@ -358,9 +379,9 @@ def train_model( print(f'lr_warmup_steps = {lr_warmup_steps}') run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"' - - run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop' - + + run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop' + if v2: run_cmd += ' --v2' if v_parameterization: @@ -390,7 +411,7 @@ def train_model( if not float(prior_loss_weight) == 1.0: run_cmd += f' --prior_loss_weight={prior_loss_weight}' run_cmd += f' --network_module=networks.lora' - + if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0): if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0): run_cmd += f' --text_encoder_lr={text_encoder_lr}' @@ -402,14 +423,12 @@ def train_model( run_cmd += f' --unet_lr={unet_lr}' run_cmd += f' --network_train_unet_only' else: - if float(text_encoder_lr) == 0: - msgbox( - 'Please input learning rate values.' - ) + if float(text_encoder_lr) == 0: + msgbox('Please input learning rate values.') return - + run_cmd += f' --network_dim={network_dim}' - + if not lora_network_weights == '': run_cmd += f' --network_weights="{lora_network_weights}"' if int(gradient_accumulation_steps) > 1: @@ -454,6 +473,9 @@ def train_model( use_8bit_adam=use_8bit_adam, keep_tokens=keep_tokens, persistent_data_loader_workers=persistent_data_loader_workers, + bucket_no_upscale=bucket_no_upscale, + random_crop=random_crop, + bucket_reso_steps=bucket_reso_steps, ) print(run_cmd) @@ -675,11 +697,13 @@ def lora_tab( label='Prior loss weight', value=1.0 ) lr_scheduler_num_cycles = gr.Textbox( - label='LR number of cycles', placeholder='(Optional) For Cosine with restart and polynomial only' + label='LR number of cycles', + placeholder='(Optional) For Cosine with restart and polynomial only', ) - + lr_scheduler_power = gr.Textbox( - label='LR power', placeholder='(Optional) For Cosine with restart and polynomial only' + label='LR power', + placeholder='(Optional) For Cosine with restart and polynomial only', ) ( use_8bit_adam, @@ -698,6 +722,9 @@ def lora_tab( max_data_loader_n_workers, keep_tokens, persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -719,7 +746,6 @@ def lora_tab( gradio_merge_lora_tab() gradio_resize_lora_tab() gradio_verify_lora_tab() - button_run = gr.Button('Train model') @@ -773,8 +799,12 @@ def lora_tab( network_alpha, training_comment, keep_tokens, - lr_scheduler_num_cycles, lr_scheduler_power, + lr_scheduler_num_cycles, + lr_scheduler_power, persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, ] button_open_config.click( diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index d35a78c..b34ca6d 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -82,8 +82,18 @@ def save_configuration( max_data_loader_n_workers, mem_eff_attn, gradient_accumulation_steps, - model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens, + model_list, + token_string, + init_word, + num_vectors_per_token, + max_train_steps, + weights, + template, + keep_tokens, persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -171,8 +181,18 @@ def open_configuration( max_data_loader_n_workers, mem_eff_attn, gradient_accumulation_steps, - 
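In lora_gui.py above, train_model now rejects a bucket_reso_steps value below 1 before the command is assembled. A minimal sketch of the guard, with msgbox swapped for an exception so it runs outside the GUI (that substitution is an assumption for illustration, not repository code):

    def check_bucket_reso_steps(bucket_reso_steps) -> None:
        # Mirrors the new guard in train_model; msgbox is replaced by ValueError here.
        if int(bucket_reso_steps) < 1:
            raise ValueError('Bucket resolution steps need to be greater than 0')

    check_bucket_reso_steps(64)  # the GUI default passes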
diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py
index d35a78c..b34ca6d 100644
--- a/textual_inversion_gui.py
+++ b/textual_inversion_gui.py
@@ -82,8 +82,18 @@ def save_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
+    model_list,
+    token_string,
+    init_word,
+    num_vectors_per_token,
+    max_train_steps,
+    weights,
+    template,
+    keep_tokens,
     persistent_data_loader_workers,
+    bucket_no_upscale,
+    random_crop,
+    bucket_reso_steps,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -171,8 +181,18 @@ def open_configuration(
     max_data_loader_n_workers,
     mem_eff_attn,
     gradient_accumulation_steps,
-    model_list, token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
+    model_list,
+    token_string,
+    init_word,
+    num_vectors_per_token,
+    max_train_steps,
+    weights,
+    template,
+    keep_tokens,
     persistent_data_loader_workers,
+    bucket_no_upscale,
+    random_crop,
+    bucket_reso_steps,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -241,8 +261,17 @@ def train_model(
     mem_eff_attn,
     gradient_accumulation_steps,
     model_list,  # Keep this. Yes, it is unused here but required given the common list used
-    token_string, init_word, num_vectors_per_token, max_train_steps, weights, template, keep_tokens,
+    token_string,
+    init_word,
+    num_vectors_per_token,
+    max_train_steps,
+    weights,
+    template,
+    keep_tokens,
     persistent_data_loader_workers,
+    bucket_no_upscale,
+    random_crop,
+    bucket_reso_steps,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -264,15 +293,15 @@ def train_model(
     if output_dir == '':
         msgbox('Output folder path is missing')
         return
-    
+
     if token_string == '':
         msgbox('Token string is missing')
         return
-    
+
     if init_word == '':
         msgbox('Init word is missing')
         return
-    
+
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
@@ -332,7 +361,7 @@ def train_model(
         )
     else:
         max_train_steps = int(max_train_steps)
-    
+
     print(f'max_train_steps = {max_train_steps}')

     # calculate stop encoder training
@@ -421,6 +450,9 @@ def train_model(
         use_8bit_adam=use_8bit_adam,
         keep_tokens=keep_tokens,
         persistent_data_loader_workers=persistent_data_loader_workers,
+        bucket_no_upscale=bucket_no_upscale,
+        random_crop=random_crop,
+        bucket_reso_steps=bucket_reso_steps,
     )
     run_cmd += f' --token_string="{token_string}"'
     run_cmd += f' --init_word="{init_word}"'
@@ -431,7 +463,7 @@ def train_model(
         run_cmd += f' --use_object_template'
     elif template == 'style template':
         run_cmd += f' --use_style_template'
-    
+
     print(run_cmd)
     # Run the command
     subprocess.run(run_cmd)
@@ -576,9 +608,7 @@ def ti_tab(
                     label='Resume TI training',
                     placeholder='(Optional) Path to existing TI embeding file to keep training',
                 )
-                weights_file_input = gr.Button(
-                    '📂', elem_id='open_folder_small'
-                )
+                weights_file_input = gr.Button('📂', elem_id='open_folder_small')
                 weights_file_input.click(get_file_path, outputs=weights)
             with gr.Row():
                 token_string = gr.Textbox(
@@ -676,6 +706,9 @@ def ti_tab(
                 max_data_loader_n_workers,
                 keep_tokens,
                 persistent_data_loader_workers,
+                bucket_no_upscale,
+                random_crop,
+                bucket_reso_steps,
             ) = gradio_advanced_training()
             color_aug.change(
                 color_aug_changed,
@@ -739,9 +772,17 @@ def ti_tab(
         mem_eff_attn,
         gradient_accumulation_steps,
         model_list,
-        token_string, init_word, num_vectors_per_token, max_train_steps, weights, template,
+        token_string,
+        init_word,
+        num_vectors_per_token,
+        max_train_steps,
+        weights,
+        template,
         keep_tokens,
         persistent_data_loader_workers,
+        bucket_no_upscale,
+        random_crop,
+        bucket_reso_steps,
     ]

     button_open_config.click(
diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py
deleted file mode 100644
index f1aecb3..0000000
--- a/tools/resize_images_to_resolution.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import os
-import cv2
-import argparse
-import shutil
-import math
-
-def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
-    # Calculate max_pixels from max_resolution string
-    max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
-
-    # Create destination folder if it does not exist
-    if not os.path.exists(dst_img_folder):
-        os.makedirs(dst_img_folder)
-
-    # Iterate through all files in src_img_folder
-    for filename in os.listdir(src_img_folder):
-        # Check if the image is png, jpg or webp
-        if not filename.endswith(('.png', '.jpg', '.webp')):
-            # Copy the file to the destination folder if not png, jpg or webp
-            shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
-            continue
-
-        # Load image
-        img = cv2.imread(os.path.join(src_img_folder, filename))
-
-        # Calculate current number of pixels
-        current_pixels = img.shape[0] * img.shape[1]
-
-        # Check if the image needs resizing
-        if current_pixels > max_pixels:
-            # Calculate scaling factor
-            scale_factor = max_pixels / current_pixels
-
-            # Calculate new dimensions
-            new_height = int(img.shape[0] * math.sqrt(scale_factor))
-            new_width = int(img.shape[1] * math.sqrt(scale_factor))
-
-            # Resize image
-            img = cv2.resize(img, (new_width, new_height))
-
-        # Calculate the new height and width that are divisible by divisible_by
-        new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
-        new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
-
-        # Center crop the image to the calculated dimensions
-        y = int((img.shape[0] - new_height) / 2)
-        x = int((img.shape[1] - new_width) / 2)
-        img = img[y:y + new_height, x:x + new_width]
-
-        # Save resized image in dst_img_folder
-        cv2.imwrite(os.path.join(dst_img_folder, filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
-
-        print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]}")
-
-
-def main():
-    parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution')
-    parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
-    parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
-    parser.add_argument('--max_resolution', type=str, help='Maximum resolution in the format "512x512"', default="512x512")
-    parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=2)
-    args = parser.parse_args()
-    resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution)
-
-if __name__ == '__main__':
-    main()
\ No newline at end of file
diff --git a/tools/resize_images_to_resolutions.py b/tools/resize_images_to_resolutions.py
new file mode 100644
index 0000000..04c1c2c
--- /dev/null
+++ b/tools/resize_images_to_resolutions.py
@@ -0,0 +1,79 @@
+import os
+import cv2
+import argparse
+import shutil
+import math
+
+def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
+    # Split the max_resolution string by "," and strip any whitespaces
+    max_resolutions = [res.strip() for res in max_resolution.split(',')]
+
+    # # Calculate max_pixels from max_resolution string
+    # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
+
+    # Create destination folder if it does not exist
+    if not os.path.exists(dst_img_folder):
+        os.makedirs(dst_img_folder)
+
+    # Iterate through all files in src_img_folder
+    for filename in os.listdir(src_img_folder):
+        # Check if the image is png, jpg or webp
+        if not filename.endswith(('.png', '.jpg', '.webp')):
+            # Copy the file to the destination folder if not png, jpg or webp
+            shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
+            continue
+
+        # Load image
+        img = cv2.imread(os.path.join(src_img_folder, filename))
+
+        for max_resolution in max_resolutions:
+            # Calculate max_pixels from max_resolution string
+            max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
+
+            # Calculate current number of pixels
+            current_pixels = img.shape[0] * img.shape[1]
+
+            # Check if the image needs resizing
+            if current_pixels > max_pixels:
+                # Calculate scaling factor
+                scale_factor = max_pixels / current_pixels
+
+                # Calculate new dimensions
+                new_height = int(img.shape[0] * math.sqrt(scale_factor))
+                new_width = int(img.shape[1] * math.sqrt(scale_factor))
+
+                # Resize image
+                img = cv2.resize(img, (new_width, new_height))
+            else:
+                # Image is already small enough; keep its current dimensions
+                new_height, new_width = img.shape[0], img.shape[1]
+
+            # Calculate the new height and width that are divisible by divisible_by
+            new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
+            new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
+
+            # Center crop the image to the calculated dimensions
+            y = int((img.shape[0] - new_height) / 2)
+            x = int((img.shape[1] - new_width) / 2)
+            img = img[y:y + new_height, x:x + new_width]
+
+            # Split filename into base and extension
+            base, ext = os.path.splitext(filename)
+            new_filename = base + '+' + max_resolution + '.jpg'
+
+            # Save resized image in dst_img_folder
+            cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
+            print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution(s)')
+    parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
+    parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
+    parser.add_argument('--max_resolution', type=str, help='Maximum resolution(s) in the format "512x512,384x384, etc, etc"', default="512x512,384x384,256x256,128x128")
+    parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=1)
+    args = parser.parse_args()
+    resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, args.divisible_by)
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
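The replacement tool above resizes each image once per requested resolution and appends the resolution to the output filename, so one source image yields one JPEG per target size. A sketch of driving it from Python instead of the CLI (assumes it is run from the repository root so tools/ is importable; the paths and step size are hypothetical):

    from tools.resize_images_to_resolutions import resize_images

    # Writes img+512x512.jpg, img+256x256.jpg, ... into dst/ for every source image.
    resize_images('src/', 'dst/', max_resolution='512x512,256x256', divisible_by=64)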