diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 94352de..97eb609 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -70,7 +70,11 @@ def save_configuration( flip_aug, clip_skip, vae, + output_name, ): + # Get list of function parameters and values + parameters = list(locals().items()) + original_file_path = file_path save_as_bool = True if save_as.get('label') == 'True' else False @@ -90,47 +94,57 @@ def save_configuration( # Return the values of the variables as a dictionary variables = { - 'pretrained_model_name_or_path': pretrained_model_name_or_path, - 'v2': v2, - 'v_parameterization': v_parameterization, - 'logging_dir': logging_dir, - 'train_data_dir': train_data_dir, - 'reg_data_dir': reg_data_dir, - 'output_dir': output_dir, - 'max_resolution': max_resolution, - 'learning_rate': learning_rate, - 'lr_scheduler': lr_scheduler, - 'lr_warmup': lr_warmup, - 'train_batch_size': train_batch_size, - 'epoch': epoch, - 'save_every_n_epochs': save_every_n_epochs, - 'mixed_precision': mixed_precision, - 'save_precision': save_precision, - 'seed': seed, - 'num_cpu_threads_per_process': num_cpu_threads_per_process, - 'cache_latent': cache_latent, - 'caption_extention': caption_extention, - 'enable_bucket': enable_bucket, - 'gradient_checkpointing': gradient_checkpointing, - 'full_fp16': full_fp16, - 'no_token_padding': no_token_padding, - 'stop_text_encoder_training': stop_text_encoder_training, - 'use_8bit_adam': use_8bit_adam, - 'xformers': xformers, - 'save_model_as': save_model_as, - 'shuffle_caption': shuffle_caption, - 'save_state': save_state, - 'resume': resume, - 'prior_loss_weight': prior_loss_weight, - 'color_aug': color_aug, - 'flip_aug': flip_aug, - 'clip_skip': clip_skip, - 'vae': vae, + name: value + for name, value in parameters # locals().items() + if name + not in [ + 'file_path', + 'save_as', + ] } + # variables = { + # 'pretrained_model_name_or_path': pretrained_model_name_or_path, + # 'v2': v2, + # 'v_parameterization': v_parameterization, + # 'logging_dir': logging_dir, + # 'train_data_dir': train_data_dir, + # 'reg_data_dir': reg_data_dir, + # 'output_dir': output_dir, + # 'max_resolution': max_resolution, + # 'learning_rate': learning_rate, + # 'lr_scheduler': lr_scheduler, + # 'lr_warmup': lr_warmup, + # 'train_batch_size': train_batch_size, + # 'epoch': epoch, + # 'save_every_n_epochs': save_every_n_epochs, + # 'mixed_precision': mixed_precision, + # 'save_precision': save_precision, + # 'seed': seed, + # 'num_cpu_threads_per_process': num_cpu_threads_per_process, + # 'cache_latent': cache_latent, + # 'caption_extention': caption_extention, + # 'enable_bucket': enable_bucket, + # 'gradient_checkpointing': gradient_checkpointing, + # 'full_fp16': full_fp16, + # 'no_token_padding': no_token_padding, + # 'stop_text_encoder_training': stop_text_encoder_training, + # 'use_8bit_adam': use_8bit_adam, + # 'xformers': xformers, + # 'save_model_as': save_model_as, + # 'shuffle_caption': shuffle_caption, + # 'save_state': save_state, + # 'resume': resume, + # 'prior_loss_weight': prior_loss_weight, + # 'color_aug': color_aug, + # 'flip_aug': flip_aug, + # 'clip_skip': clip_skip, + # 'vae': vae, + # 'output_name': output_name, + # } # Save the data to the selected file with open(file_path, 'w') as file: - json.dump(variables, file) + json.dump(variables, file, indent=2) return file_path @@ -173,7 +187,10 @@ def open_configuration( flip_aug, clip_skip, vae, + output_name, ): + # Get list of function parameters and values + parameters = list(locals().items()) original_file_path = file_path file_path = get_file_path(file_path) @@ -187,50 +204,59 @@ def open_configuration( file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action my_data = {} + values = [file_path] + for key, value in parameters: + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found + if not key in ['file_path']: + values.append(my_data.get(key, value)) + # print(values) + return tuple(values) + # Return the values of the variables as a dictionary - return ( - file_path, - my_data.get( - 'pretrained_model_name_or_path', pretrained_model_name_or_path - ), - my_data.get('v2', v2), - my_data.get('v_parameterization', v_parameterization), - my_data.get('logging_dir', logging_dir), - my_data.get('train_data_dir', train_data_dir), - my_data.get('reg_data_dir', reg_data_dir), - my_data.get('output_dir', output_dir), - my_data.get('max_resolution', max_resolution), - my_data.get('learning_rate', learning_rate), - my_data.get('lr_scheduler', lr_scheduler), - my_data.get('lr_warmup', lr_warmup), - my_data.get('train_batch_size', train_batch_size), - my_data.get('epoch', epoch), - my_data.get('save_every_n_epochs', save_every_n_epochs), - my_data.get('mixed_precision', mixed_precision), - my_data.get('save_precision', save_precision), - my_data.get('seed', seed), - my_data.get( - 'num_cpu_threads_per_process', num_cpu_threads_per_process - ), - my_data.get('cache_latent', cache_latent), - my_data.get('caption_extention', caption_extention), - my_data.get('enable_bucket', enable_bucket), - my_data.get('gradient_checkpointing', gradient_checkpointing), - my_data.get('full_fp16', full_fp16), - my_data.get('no_token_padding', no_token_padding), - my_data.get('stop_text_encoder_training', stop_text_encoder_training), - my_data.get('use_8bit_adam', use_8bit_adam), - my_data.get('xformers', xformers), - my_data.get('save_model_as', save_model_as), - my_data.get('shuffle_caption', shuffle_caption), - my_data.get('save_state', save_state), - my_data.get('resume', resume), - my_data.get('prior_loss_weight', prior_loss_weight), - my_data.get('color_aug', color_aug), - my_data.get('flip_aug', flip_aug), - my_data.get('clip_skip', clip_skip), - my_data.get('vae', vae), - ) + # return ( + # file_path, + # my_data.get( + # 'pretrained_model_name_or_path', pretrained_model_name_or_path + # ), + # my_data.get('v2', v2), + # my_data.get('v_parameterization', v_parameterization), + # my_data.get('logging_dir', logging_dir), + # my_data.get('train_data_dir', train_data_dir), + # my_data.get('reg_data_dir', reg_data_dir), + # my_data.get('output_dir', output_dir), + # my_data.get('max_resolution', max_resolution), + # my_data.get('learning_rate', learning_rate), + # my_data.get('lr_scheduler', lr_scheduler), + # my_data.get('lr_warmup', lr_warmup), + # my_data.get('train_batch_size', train_batch_size), + # my_data.get('epoch', epoch), + # my_data.get('save_every_n_epochs', save_every_n_epochs), + # my_data.get('mixed_precision', mixed_precision), + # my_data.get('save_precision', save_precision), + # my_data.get('seed', seed), + # my_data.get( + # 'num_cpu_threads_per_process', num_cpu_threads_per_process + # ), + # my_data.get('cache_latent', cache_latent), + # my_data.get('caption_extention', caption_extention), + # my_data.get('enable_bucket', enable_bucket), + # my_data.get('gradient_checkpointing', gradient_checkpointing), + # my_data.get('full_fp16', full_fp16), + # my_data.get('no_token_padding', no_token_padding), + # my_data.get('stop_text_encoder_training', stop_text_encoder_training), + # my_data.get('use_8bit_adam', use_8bit_adam), + # my_data.get('xformers', xformers), + # my_data.get('save_model_as', save_model_as), + # my_data.get('shuffle_caption', shuffle_caption), + # my_data.get('save_state', save_state), + # my_data.get('resume', resume), + # my_data.get('prior_loss_weight', prior_loss_weight), + # my_data.get('color_aug', color_aug), + # my_data.get('flip_aug', flip_aug), + # my_data.get('clip_skip', clip_skip), + # my_data.get('vae', vae), + # my_data.get('output_name', output_name), + # ) def train_model( @@ -270,21 +296,30 @@ def train_model( flip_aug, clip_skip, vae, + output_name, ): - def save_inference_file(output_dir, v2, v_parameterization): - # Copy inference model for v2 if required - if v2 and v_parameterization: - print(f'Saving v2-inference-v.yaml as {output_dir}/last.yaml') - shutil.copy( - f'./v2_inference/v2-inference-v.yaml', - f'{output_dir}/last.yaml', - ) - elif v2: - print(f'Saving v2-inference.yaml as {output_dir}/last.yaml') - shutil.copy( - f'./v2_inference/v2-inference.yaml', - f'{output_dir}/last.yaml', - ) + def save_inference_file(output_dir, v2, v_parameterization, output_name): + # List all files in the directory + files = os.listdir(output_dir) + + # Iterate over the list of files + for file in files: + # Check if the file starts with the value of save_inference_file + if file.startswith(output_name): + # Copy the v2-inference-v.yaml file to the current file, with a .yaml extension + if v2 and v_parameterization: + print(f'Saving v2-inference-v.yaml as {output_dir}/{file}.yaml') + shutil.copy( + f'./v2_inference/v2-inference-v.yaml', + f'{output_dir}/{file}.yaml', + ) + elif v2: + print(f'Saving v2-inference.yaml as {output_dir}/{file}.yaml') + shutil.copy( + f'./v2_inference/v2-inference.yaml', + f'{output_dir}/{file}.yaml', + ) + if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -437,17 +472,19 @@ def train_model( run_cmd += f' --clip_skip={str(clip_skip)}' if not vae == '': run_cmd += f' --vae="{vae}"' + if not output_name == '': + run_cmd += f' --output_name="{output_name}"' print(run_cmd) # Run the command subprocess.run(run_cmd) # check if output_dir/last is a folder... therefore it is a diffuser model - last_dir = pathlib.Path(f'{output_dir}/last') + last_dir = pathlib.Path(f'{output_dir}/{output_name}') if not last_dir.is_dir(): # Copy inference model for v2 if required - save_inference_file(output_dir, v2, v_parameterization) + save_inference_file(output_dir, v2, v_parameterization, output_name) def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): @@ -656,9 +693,7 @@ def dreambooth_tab( output_dir_input_folder = gr.Button( '📂', elem_id='open_folder_small' ) - output_dir_input_folder.click( - get_folder_path, outputs=output_dir - ) + output_dir_input_folder.click(get_folder_path, outputs=output_dir) logging_dir = gr.Textbox( label='Logging folder', placeholder='Optional: enable logging and output TensorBoard log to this folder', @@ -669,6 +704,13 @@ def dreambooth_tab( logging_dir_input_folder.click( get_folder_path, outputs=logging_dir ) + with gr.Row(): + output_name = gr.Textbox( + label='Model output name', + placeholder='Name of the model to output', + value='last', + interactive=True, + ) train_data_dir.change( remove_doublequote, inputs=[train_data_dir], @@ -763,13 +805,9 @@ def dreambooth_tab( label='Stop text encoder training', ) with gr.Row(): - enable_bucket = gr.Checkbox( - label='Enable buckets', value=True - ) + enable_bucket = gr.Checkbox(label='Enable buckets', value=True) cache_latent = gr.Checkbox(label='Cache latent', value=True) - use_8bit_adam = gr.Checkbox( - label='Use 8bit adam', value=True - ) + use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) xformers = gr.Checkbox(label='Use xformers', value=True) with gr.Accordion('Advanced Configuration', open=False): with gr.Row(): @@ -831,7 +869,7 @@ def dreambooth_tab( ) button_run = gr.Button('Train model') - + settings_list = [ pretrained_model_name_or_path, v2, @@ -869,6 +907,7 @@ def dreambooth_tab( flip_aug, clip_skip, vae, + output_name, ] button_open_config.click(