diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 914bb8f..a7b67ae 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -334,8 +334,8 @@ def train_model( run_cmd += ' --xformers' if shuffle_caption: run_cmd += ' --shuffle_caption' - if save_state: - run_cmd += ' --save_state' + # if save_state: + # run_cmd += ' --save_state' if color_aug: run_cmd += ' --color_aug' if flip_aug: @@ -368,8 +368,8 @@ def train_model( ) if not save_model_as == 'same as source model': run_cmd += f' --save_model_as={save_model_as}' - if not resume == '': - run_cmd += f' --resume={resume}' + # if not resume == '': + # run_cmd += f' --resume={resume}' if not float(prior_loss_weight) == 1.0: run_cmd += f' --prior_loss_weight={prior_loss_weight}' if int(clip_skip) > 1: @@ -384,7 +384,13 @@ def train_model( run_cmd += f' --max_train_epochs="{max_train_epochs}"' if not max_data_loader_n_workers == '': run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"' - run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers) + run_cmd += run_cmd_advanced_training( + max_train_epochs=max_train_epochs, + max_data_loader_n_workers=max_data_loader_n_workers, + max_token_length=max_token_length, + resume=resume, + save_state=save_state, + ) print(run_cmd) # Run the command @@ -681,9 +687,6 @@ def dreambooth_tab( label='Shuffle caption', value=False ) with gr.Row(): - save_state = gr.Checkbox( - label='Save training state', value=False - ) color_aug = gr.Checkbox( label='Color augmentation', value=False ) @@ -697,12 +700,6 @@ def dreambooth_tab( label='Clip skip', value='1', minimum=1, maximum=12, step=1 ) with gr.Row(): - resume = gr.Textbox( - label='Resume from saved training state', - placeholder='path to "last-state" state folder to resume from', - ) - resume_button = gr.Button('📂', elem_id='open_folder_small') - resume_button.click(get_folder_path, outputs=resume) prior_loss_weight = gr.Number( label='Prior loss weight', value=1.0 ) @@ -712,25 +709,7 @@ def dreambooth_tab( ) vae_button = gr.Button('📂', elem_id='open_folder_small') vae_button.click(get_any_file_path, outputs=vae) - max_token_length = gr.Dropdown( - label='Max Token Length', - choices=[ - '75', - '150', - '225', - ], - value='75', - ) - max_train_epochs, max_data_loader_n_workers = gradio_advanced_training() - # with gr.Row(): - # max_train_epochs = gr.Textbox( - # label='Max train epoch', - # placeholder='(Optional) Override number of epoch', - # ) - # max_data_loader_n_workers = gr.Textbox( - # label='Max num workers for DataLoader', - # placeholder='(Optional) Override number of epoch. Default: 8', - # ) + save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training() with gr.Tab('Tools'): gr.Markdown( 'This section provide Dreambooth tools to help setup your dataset...' diff --git a/finetune_gui.py b/finetune_gui.py index 2459058..8fae590 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -335,15 +335,21 @@ def train_model( run_cmd += f' --clip_skip={str(clip_skip)}' if int(gradient_accumulation_steps) > 1: run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' - if save_state: - run_cmd += ' --save_state' - if not resume == '': - run_cmd += f' --resume={resume}' + # if save_state: + # run_cmd += ' --save_state' + # if not resume == '': + # run_cmd += f' --resume={resume}' if not output_name == '': run_cmd += f' --output_name="{output_name}"' if (int(max_token_length) > 75): run_cmd += f' --max_token_length={max_token_length}' - run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers) + run_cmd += run_cmd_advanced_training( + max_train_epochs=max_train_epochs, + max_data_loader_n_workers=max_data_loader_n_workers, + max_token_length=max_token_length, + resume=resume, + save_state=save_state, + ) print(run_cmd) # Run the command @@ -640,31 +646,13 @@ def finetune_tab(): label='Shuffle caption', value=False ) with gr.Row(): - save_state = gr.Checkbox( - label='Save training state', value=False - ) - resume = gr.Textbox( - label='Resume from saved training state', - placeholder='path to "last-state" state folder to resume from', - ) - resume_button = gr.Button('📂', elem_id='open_folder_small') - resume_button.click(get_folder_path, outputs=resume) gradient_checkpointing = gr.Checkbox( label='Gradient checkpointing', value=False ) gradient_accumulation_steps = gr.Number( label='Gradient accumulate steps', value='1' ) - max_token_length = gr.Dropdown( - label='Max Token Length', - choices=[ - '75', - '150', - '225', - ], - value='75', - ) - max_train_epochs, max_data_loader_n_workers = gradio_advanced_training() + save_state, resume, max_token_length, max_train_epochs, max_data_loader_n_workers = gradio_advanced_training() with gr.Box(): with gr.Row(): create_caption = gr.Checkbox( diff --git a/library/common_gui.py b/library/common_gui.py index 167bc00..1d193d5 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -4,10 +4,12 @@ import gradio as gr from easygui import msgbox import shutil + def get_dir_and_file(file_path): dir_path, file_name = os.path.split(file_path) return (dir_path, file_name) + def has_ext_files(directory, extension): # Iterate through all the files in the directory for file in os.listdir(directory): @@ -17,18 +19,26 @@ def has_ext_files(directory, extension): # If no extension files were found, return False return False -def get_file_path(file_path='', defaultextension='.json', extension_name='Config files'): + +def get_file_path( + file_path='', defaultextension='.json', extension_name='Config files' +): current_file_path = file_path # print(f'current file path: {current_file_path}') - + initial_dir, initial_file = get_dir_and_file(file_path) root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() file_path = filedialog.askopenfilename( - filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')), - defaultextension=defaultextension, initialfile=initial_file, initialdir=initial_dir + filetypes=( + (f'{extension_name}', f'{defaultextension}'), + ('All files', '*'), + ), + defaultextension=defaultextension, + initialfile=initial_file, + initialdir=initial_dir, ) root.destroy() @@ -37,17 +47,20 @@ def get_file_path(file_path='', defaultextension='.json', extension_name='Config return file_path + def get_any_file_path(file_path=''): current_file_path = file_path # print(f'current file path: {current_file_path}') - + initial_dir, initial_file = get_dir_and_file(file_path) root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() - file_path = filedialog.askopenfilename(initialdir=initial_dir, - initialfile=initial_file,) + file_path = filedialog.askopenfilename( + initialdir=initial_dir, + initialfile=initial_file, + ) root.destroy() if file_path == '': @@ -65,7 +78,7 @@ def remove_doublequote(file_path): def get_folder_path(folder_path=''): current_folder_path = folder_path - + initial_dir, initial_file = get_dir_and_file(folder_path) root = Tk() @@ -80,17 +93,22 @@ def get_folder_path(folder_path=''): return folder_path -def get_saveasfile_path(file_path='', defaultextension='.json', extension_name='Config files'): +def get_saveasfile_path( + file_path='', defaultextension='.json', extension_name='Config files' +): current_file_path = file_path # print(f'current file path: {current_file_path}') - + initial_dir, initial_file = get_dir_and_file(file_path) root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() save_file_path = filedialog.asksaveasfile( - filetypes=((f'{extension_name}', f'{defaultextension}'), ('All files', '*')), + filetypes=( + (f'{extension_name}', f'{defaultextension}'), + ('All files', '*'), + ), defaultextension=defaultextension, initialdir=initial_dir, initialfile=initial_file, @@ -109,16 +127,20 @@ def get_saveasfile_path(file_path='', defaultextension='.json', extension_name=' return file_path -def get_saveasfilename_path(file_path='', extensions='*', extension_name='Config files'): + +def get_saveasfilename_path( + file_path='', extensions='*', extension_name='Config files' +): current_file_path = file_path # print(f'current file path: {current_file_path}') - + initial_dir, initial_file = get_dir_and_file(file_path) root = Tk() root.wm_attributes('-topmost', 1) root.withdraw() - save_file_path = filedialog.asksaveasfilename(filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')), + save_file_path = filedialog.asksaveasfilename( + filetypes=((f'{extension_name}', f'{extensions}'), ('All files', '*')), defaultextension=extensions, initialdir=initial_dir, initialfile=initial_file, @@ -138,9 +160,11 @@ def add_pre_postfix( folder='', prefix='', postfix='', caption_file_ext='.caption' ): if not has_ext_files(folder, caption_file_ext): - msgbox(f'No files with extension {caption_file_ext} were found in {folder}...') + msgbox( + f'No files with extension {caption_file_ext} were found in {folder}...' + ) return - + if prefix == '' and postfix == '': return @@ -157,15 +181,16 @@ def add_pre_postfix( f.seek(0, 0) f.write(f'{prefix}{content}{postfix}') f.close() - -def find_replace( - folder='', caption_file_ext='.caption', find='', replace='' -): + + +def find_replace(folder='', caption_file_ext='.caption', find='', replace=''): print('Running caption find/replace') if not has_ext_files(folder, caption_file_ext): - msgbox(f'No files with extension {caption_file_ext} were found in {folder}...') + msgbox( + f'No files with extension {caption_file_ext} were found in {folder}...' + ) return - + if find == '': return @@ -179,13 +204,17 @@ def find_replace( f.write(content) f.close() + def color_aug_changed(color_aug): if color_aug: - msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...') + msgbox( + 'Disabling "Cache latent" because "Color augmentation" has been selected...' + ) return gr.Checkbox.update(value=False, interactive=False) else: return gr.Checkbox.update(value=True, interactive=True) - + + def save_inference_file(output_dir, v2, v_parameterization, output_name): # List all files in the directory files = os.listdir(output_dir) @@ -198,21 +227,26 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name): if os.path.isfile(os.path.join(output_dir, file)): # Split the file name and extension file_name, ext = os.path.splitext(file) - + # Copy the v2-inference-v.yaml file to the current file, with a .yaml extension if v2 and v_parameterization: - print(f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml') + print( + f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml' + ) shutil.copy( f'./v2_inference/v2-inference-v.yaml', f'{output_dir}/{file_name}.yaml', ) elif v2: - print(f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml') + print( + f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml' + ) shutil.copy( f'./v2_inference/v2-inference.yaml', f'{output_dir}/{file_name}.yaml', ) + def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): # define a list of substrings to search for substrings_v2 = [ @@ -262,30 +296,63 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): v_parameterization = False return value, v2, v_parameterization - + ### ### Gradio common GUI section ### - + + def gradio_advanced_training(): with gr.Row(): - max_train_epochs = gr.Textbox( - label='Max train epoch', - placeholder='(Optional) Override number of epoch', - ) - max_data_loader_n_workers = gr.Textbox( - label='Max num workers for DataLoader', - placeholder='(Optional) Override number of epoch. Default: 8', - ) - return max_train_epochs, max_data_loader_n_workers + save_state = gr.Checkbox(label='Save training state', value=False) + resume = gr.Textbox( + label='Resume from saved training state', + placeholder='path to "last-state" state folder to resume from', + ) + resume_button = gr.Button('📂', elem_id='open_folder_small') + resume_button.click(get_folder_path, outputs=resume) + max_token_length = gr.Dropdown( + label='Max Token Length', + choices=[ + '75', + '150', + '225', + ], + value='75', + ) + with gr.Row(): + max_train_epochs = gr.Textbox( + label='Max train epoch', + placeholder='(Optional) Override number of epoch', + ) + max_data_loader_n_workers = gr.Textbox( + label='Max num workers for DataLoader', + placeholder='(Optional) Override number of epoch. Default: 8', + ) + return ( + save_state, + resume, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + ) + def run_cmd_advanced_training(**kwargs): - run_cmd = '' - max_train_epochs = kwargs.get('max_train_epochs', '') - max_data_loader_n_workers = kwargs.get('max_data_loader_n_workers', '') - if not max_train_epochs == '': - run_cmd += f' --max_train_epochs="{max_train_epochs}"' - if not max_data_loader_n_workers == '': - run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"' - - return run_cmd \ No newline at end of file + options = [ + f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"' + if kwargs.get('max_train_epochs') + else '', + f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"' + if kwargs.get('max_data_loader_n_workers') + else '', + f' --max_token_length={kwargs.get("max_token_length", "")}' + if int(kwargs.get('max_token_length', 0)) > 75 + else '', + f' --resume="{kwargs.get("resume", "")}"' + if kwargs.get('resume') + else '', + ' --save_state' if kwargs.get('save_state') else '', + ] + run_cmd = ''.join(options) + return run_cmd diff --git a/lora_gui.py b/lora_gui.py index 676f14d..2432b7a 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -19,7 +19,9 @@ from library.common_gui import ( get_saveasfile_path, color_aug_changed, save_inference_file, - set_pretrained_model_name_or_path_input, gradio_advanced_training,run_cmd_advanced_training, + set_pretrained_model_name_or_path_input, + gradio_advanced_training, + run_cmd_advanced_training, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -172,7 +174,7 @@ def open_configuration( ): # Get list of function parameters and values parameters = list(locals().items()) - + original_file_path = file_path file_path = get_file_path(file_path) @@ -180,11 +182,11 @@ def open_configuration( # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) - print("Loading config...") + print('Loading config...') else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action my_data = {} - + values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found @@ -235,7 +237,7 @@ def train_model( gradient_accumulation_steps, mem_eff_attn, output_name, - model_list, # Keep this. Yes, it is unused here but required given the common list used + model_list, # Keep this. Yes, it is unused here but required given the common list used max_token_length, max_train_epochs, max_data_loader_n_workers, @@ -350,8 +352,8 @@ def train_model( run_cmd += ' --xformers' if shuffle_caption: run_cmd += ' --shuffle_caption' - if save_state: - run_cmd += ' --save_state' + # if save_state: + # run_cmd += ' --save_state' if color_aug: run_cmd += ' --color_aug' if flip_aug: @@ -386,8 +388,8 @@ def train_model( ) if not save_model_as == 'same as source model': run_cmd += f' --save_model_as={save_model_as}' - if not resume == '': - run_cmd += f' --resume="{resume}"' + # if not resume == '': + # run_cmd += f' --resume="{resume}"' if not float(prior_loss_weight) == 1.0: run_cmd += f' --prior_loss_weight={prior_loss_weight}' run_cmd += f' --network_module=networks.lora' @@ -414,9 +416,15 @@ def train_model( # run_cmd += f' --vae="{vae}"' if not output_name == '': run_cmd += f' --output_name="{output_name}"' - if (int(max_token_length) > 75): - run_cmd += f' --max_token_length={max_token_length}' - run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers) + # if (int(max_token_length) > 75): + # run_cmd += f' --max_token_length={max_token_length}' + run_cmd += run_cmd_advanced_training( + max_train_epochs=max_train_epochs, + max_data_loader_n_workers=max_data_loader_n_workers, + max_token_length=max_token_length, + resume=resume, + save_state=save_state, + ) print(run_cmd) # Run the command @@ -564,9 +572,7 @@ def lora_tab( label='Image folder', placeholder='Folder where the training folders containing the images are located', ) - train_data_dir_folder = gr.Button( - '📂', elem_id='open_folder_small' - ) + train_data_dir_folder = gr.Button('📂', elem_id='open_folder_small') train_data_dir_folder.click( get_folder_path, outputs=train_data_dir ) @@ -574,33 +580,21 @@ def lora_tab( label='Regularisation folder', placeholder='(Optional) Folder where where the regularization folders containing the images are located', ) - reg_data_dir_folder = gr.Button( - '📂', elem_id='open_folder_small' - ) - reg_data_dir_folder.click( - get_folder_path, outputs=reg_data_dir - ) + reg_data_dir_folder = gr.Button('📂', elem_id='open_folder_small') + reg_data_dir_folder.click(get_folder_path, outputs=reg_data_dir) with gr.Row(): output_dir = gr.Textbox( label='Output folder', placeholder='Folder to output trained model', ) - output_dir_folder = gr.Button( - '📂', elem_id='open_folder_small' - ) - output_dir_folder.click( - get_folder_path, outputs=output_dir - ) + output_dir_folder = gr.Button('📂', elem_id='open_folder_small') + output_dir_folder.click(get_folder_path, outputs=output_dir) logging_dir = gr.Textbox( label='Logging folder', placeholder='Optional: enable logging and output TensorBoard log to this folder', ) - logging_dir_folder = gr.Button( - '📂', elem_id='open_folder_small' - ) - logging_dir_folder.click( - get_folder_path, outputs=logging_dir - ) + logging_dir_folder = gr.Button('📂', elem_id='open_folder_small') + logging_dir_folder.click(get_folder_path, outputs=logging_dir) with gr.Row(): output_name = gr.Textbox( label='Model output name', @@ -659,11 +653,13 @@ def lora_tab( with gr.Row(): text_encoder_lr = gr.Textbox( label='Text Encoder learning rate', - value="5e-5", + value='5e-5', placeholder='Optional', ) unet_lr = gr.Textbox( - label='Unet learning rate', value="1e-3", placeholder='Optional' + label='Unet learning rate', + value='1e-3', + placeholder='Optional', ) network_dim = gr.Slider( minimum=1, @@ -731,13 +727,9 @@ def lora_tab( label='Stop text encoder training', ) with gr.Row(): - enable_bucket = gr.Checkbox( - label='Enable buckets', value=True - ) + enable_bucket = gr.Checkbox(label='Enable buckets', value=True) cache_latent = gr.Checkbox(label='Cache latent', value=True) - use_8bit_adam = gr.Checkbox( - label='Use 8bit adam', value=True - ) + use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) xformers = gr.Checkbox(label='Use xformers', value=True) with gr.Accordion('Advanced Configuration', open=False): with gr.Row(): @@ -777,33 +769,14 @@ def lora_tab( mem_eff_attn = gr.Checkbox( label='Memory efficient attention', value=False ) - with gr.Row(): - save_state = gr.Checkbox( - label='Save training state', value=False - ) - resume = gr.Textbox( - label='Resume from saved training state', - placeholder='path to "last-state" state folder to resume from', - ) - resume_button = gr.Button('📂', elem_id='open_folder_small') - resume_button.click(get_folder_path, outputs=resume) - # vae = gr.Textbox( - # label='VAE', - # placeholder='(Optiona) path to checkpoint of vae to replace for training', - # ) - # vae_button = gr.Button('📂', elem_id='open_folder_small') - # vae_button.click(get_any_file_path, outputs=vae) - max_token_length = gr.Dropdown( - label='Max Token Length', - choices=[ - '75', - '150', - '225', - ], - value='75', - ) - max_train_epochs, max_data_loader_n_workers = gradio_advanced_training() - + ( + save_state, + resume, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + ) = gradio_advanced_training() + with gr.Tab('Tools'): gr.Markdown( 'This section provide Dreambooth tools to help setup your dataset...'