from tkinter import filedialog, Tk
from easygui import msgbox
import os
import gradio as gr
import easygui
import shutil

folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
document_symbol = '\U0001F4C4'  # 📄

# Preset model names that require the v2 flag only
V2_BASE_MODELS = [
    'stabilityai/stable-diffusion-2-1-base',
    'stabilityai/stable-diffusion-2-base',
]

# Preset model names that require both the v2 and v_parameterization flags
V_PARAMETERIZATION_MODELS = [
    'stabilityai/stable-diffusion-2-1',
    'stabilityai/stable-diffusion-2',
]

# Preset v1.x model names (neither flag set)
V1_MODELS = [
    'CompVis/stable-diffusion-v1-4',
    'runwayml/stable-diffusion-v1-5',
]

# Every preset model name known to the GUI
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS

# When any of these environment variables is set, the Tk file/folder dialogs
# are skipped and the path passed in is returned unchanged.
FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']


def _dialogs_enabled():
    """Return True when none of the FILE_ENV_EXCLUSION variables is set."""
    return not any(var in os.environ for var in FILE_ENV_EXCLUSION)


def _create_hidden_root():
    """Create a withdrawn, always-on-top Tk root to parent a file dialog."""
    root = Tk()
    root.wm_attributes('-topmost', 1)
    root.withdraw()
    return root


def _to_int(value, default=0):
    """Convert *value* to int, using *default* when it is None or ''."""
    if value is None or value == '':
        return default
    return int(value)


def _to_float(value, default=0.0):
    """Convert *value* to float, using *default* when it is None or ''."""
    if value is None or value == '':
        return default
    return float(value)


def check_if_model_exist(output_name, output_dir, save_model_as):
    """
    Ask the user whether an already existing output model may be overwritten.

    Args:
        output_name: Base name of the model to be saved.
        output_dir: Directory the model will be saved to.
        save_model_as: Selected save format. 'diffusers' and
            'diffusers_safetensors' are checked as a folder, 'ckpt' and
            'safetensors' as a file named ``output_name.<format>``.

    Returns:
        bool: True when a model already exists and the user declined to
        overwrite it (the caller should abort), False otherwise.
    """
    if save_model_as in ['diffusers', 'diffusers_safetensors']:
        ckpt_folder = os.path.join(output_dir, output_name)
        if os.path.isdir(ckpt_folder):
            msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                print(
                    'Aborting training due to existing model with same name...'
                )
                return True
    elif save_model_as in ['ckpt', 'safetensors']:
        ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as)
        if os.path.isfile(ckpt_file):
            msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                print(
                    'Aborting training due to existing model with same name...'
                )
                return True
    else:
        # Any other format gives no file or folder name to check here.
        print(
            'Can\'t verify if existing model exist when save model is set a "same as source model", continuing to train model...'
        )
        return False

    return False


def update_my_data(my_data):
    """
    Upgrade a loaded configuration dictionary to the current set of options.

    The dictionary is modified in place and also returned.

    Args:
        my_data (dict): Configuration values, typically loaded from a file.

    Returns:
        dict: The same dictionary with legacy values migrated.
    """
    # Derive the optimizer from the legacy use_8bit_adam flag when not set
    use_8bit_adam = my_data.get('use_8bit_adam', False)
    my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW')

    # Use 'custom' when no model_list is set or the model is not a preset
    model_list = my_data.get('model_list', [])
    pretrained_model_name_or_path = my_data.get(
        'pretrained_model_name_or_path', ''
    )
    if (
        not model_list
        or pretrained_model_name_or_path not in ALL_PRESET_MODELS
    ):
        my_data['model_list'] = 'custom'

    # Convert digit strings to int; replace empty/zero values with -1
    for key in ['epoch', 'save_every_n_epochs']:
        value = my_data.get(key, -1)
        if isinstance(value, str) and value.isdigit():
            my_data[key] = int(value)
        elif not value:
            my_data[key] = -1

    # Rename the legacy LoCon type
    if my_data.get('LoRA_type', 'Standard') == 'LoCon':
        my_data['LoRA_type'] = 'LyCORIS/LoCon'

    # LoRA and TI configurations only accept safetensors or ckpt
    if (
        my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')
    ) and my_data.get('save_model_as') not in ['safetensors', 'ckpt']:
        message = 'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
        if my_data.get('LoRA_type'):
            print(message.format('LoRA'))
        if my_data.get('num_vectors_per_token'):
            print(message.format('TI'))
        my_data['save_model_as'] = 'safetensors'

    return my_data


def get_dir_and_file(file_path):
    """Split *file_path* into a ``(directory, file name)`` tuple."""
    dir_path, file_name = os.path.split(file_path)
    return (dir_path, file_name)


def get_file_path(
    file_path='', default_extension='.json', extension_name='Config files'
):
    """
    Show an "open file" dialog filtered on one extension.

    Args:
        file_path (str): Current path, used to seed the dialog and returned
            when the dialog is cancelled or dialogs are disabled.
        default_extension (str): Extension offered by the file filter.
        extension_name (str): Label shown for that filter.

    Returns:
        str: The selected path, or *file_path* when nothing was selected.
    """
    if not _dialogs_enabled():
        return file_path

    initial_dir, initial_file = get_dir_and_file(file_path)

    root = _create_hidden_root()
    try:
        selected = filedialog.askopenfilename(
            filetypes=(
                (extension_name, f'*{default_extension}'),
                ('All files', '*.*'),
            ),
            defaultextension=default_extension,
            initialfile=initial_file,
            initialdir=initial_dir,
        )
    finally:
        # Always release the hidden root, even if the dialog raised
        root.destroy()

    # A cancelled dialog returns an empty value; keep the current path then
    return selected or file_path


def get_any_file_path(file_path=''):
    """
    Show an "open file" dialog without any extension filter.

    Args:
        file_path (str): Current path, used to seed the dialog and returned
            when the dialog is cancelled or dialogs are disabled.

    Returns:
        str: The selected path, or *file_path* when nothing was selected.
    """
    if not _dialogs_enabled():
        return file_path

    initial_dir, initial_file = get_dir_and_file(file_path)

    root = _create_hidden_root()
    try:
        selected = filedialog.askopenfilename(
            initialdir=initial_dir,
            initialfile=initial_file,
        )
    finally:
        root.destroy()

    return selected or file_path


def remove_doublequote(file_path):
    """Return *file_path* with every double quote removed (None passes through)."""
    if file_path is not None:
        file_path = file_path.replace('"', '')

    return file_path


def get_folder_path(folder_path=''):
    """
    Show a "select folder" dialog.

    Args:
        folder_path (str): Current path; its directory part seeds the dialog
            and it is returned when the dialog is cancelled or disabled.

    Returns:
        str: The selected folder, or *folder_path* when nothing was selected.
    """
    if not _dialogs_enabled():
        return folder_path

    initial_dir, _ = get_dir_and_file(folder_path)

    root = _create_hidden_root()
    try:
        selected = filedialog.askdirectory(initialdir=initial_dir)
    finally:
        root.destroy()

    return selected or folder_path


def get_saveasfile_path(
    file_path='', defaultextension='.json', extension_name='Config files'
):
    """
    Show a "save as" dialog that creates the chosen file.

    Args:
        file_path (str): Current path, used to seed the dialog and returned
            when the dialog is cancelled or dialogs are disabled.
        defaultextension (str): Extension used for the filter and appended
            when the user types no extension.
        extension_name (str): Label shown for the filter.

    Returns:
        str: The chosen path, or *file_path* when nothing was chosen.
    """
    if not _dialogs_enabled():
        return file_path

    initial_dir, initial_file = get_dir_and_file(file_path)

    root = _create_hidden_root()
    try:
        save_file = filedialog.asksaveasfile(
            filetypes=(
                (f'{extension_name}', f'{defaultextension}'),
                ('All files', '*'),
            ),
            defaultextension=defaultextension,
            initialdir=initial_dir,
            initialfile=initial_file,
        )
    finally:
        root.destroy()

    if save_file is None:
        return file_path

    # asksaveasfile hands back an open file object; only its name is needed,
    # so close it right away instead of leaking the handle.
    save_file.close()
    print(save_file.name)
    return save_file.name


def get_saveasfilename_path(
    file_path='', extensions='*', extension_name='Config files'
):
    """
    Show a "save as" dialog that only returns the chosen file name.

    Args:
        file_path (str): Current path, used to seed the dialog and returned
            when the dialog is cancelled or dialogs are disabled.
        extensions (str): Pattern used for the filter and default extension.
        extension_name (str): Label shown for the filter.

    Returns:
        str: The chosen path, or *file_path* when nothing was chosen.
    """
    if not _dialogs_enabled():
        return file_path

    initial_dir, initial_file = get_dir_and_file(file_path)

    root = _create_hidden_root()
    try:
        selected = filedialog.asksaveasfilename(
            filetypes=(
                (f'{extension_name}', f'{extensions}'),
                ('All files', '*'),
            ),
            defaultextension=extensions,
            initialdir=initial_dir,
            initialfile=initial_file,
        )
    finally:
        root.destroy()

    return selected or file_path


def add_pre_postfix(
    folder: str = '',
    prefix: str = '',
    postfix: str = '',
    caption_file_ext: str = '.caption',
) -> None:
    """
    Add prefix and/or postfix to the caption file of every image in a folder.

    For each image (.jpg, .jpeg, .png, .webp) the caption file with the same
    base name is updated. When it does not exist yet, it is created with the
    requested prefix and/or postfix as its content.

    Args:
        folder (str): Path to the folder containing images and caption files.
        prefix (str, optional): Prefix to add to the content of the caption files.
        postfix (str, optional): Postfix to add to the content of the caption files.
        caption_file_ext (str, optional): Extension of the caption files.
    """
    if prefix == '' and postfix == '':
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
    image_files = [
        f for f in os.listdir(folder) if f.lower().endswith(image_extensions)
    ]

    for image_file in image_files:
        caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
        caption_file_path = os.path.join(folder, caption_file_name)

        if not os.path.exists(caption_file_path):
            with open(caption_file_path, 'w') as f:
                separator = ' ' if prefix and postfix else ''
                f.write(f'{prefix}{separator}{postfix}')
        else:
            with open(caption_file_path, 'r+') as f:
                content = f.read().rstrip()
                prefix_separator = ' ' if prefix else ''
                postfix_separator = ' ' if postfix else ''
                f.seek(0, 0)
                f.write(
                    f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}'
                )
                # Stripping trailing whitespace can make the new text shorter
                # than the old one; drop whatever is left of the old content.
                f.truncate()


def has_ext_files(folder_path: str, file_extension: str) -> bool:
    """
    Check if there are any files with the specified extension in the given folder.

    Args:
        folder_path (str): Path to the folder containing files.
        file_extension (str): Extension of the files to look for.

    Returns:
        bool: True if files with the specified extension are found, False otherwise.
    """
    return any(
        file.endswith(file_extension) for file in os.listdir(folder_path)
    )


def find_replace(
    folder_path: str = '',
    caption_file_ext: str = '.caption',
    search_text: str = '',
    replace_text: str = '',
) -> None:
    """
    Find and replace text in caption files within a folder.

    A message box is shown when the folder holds no caption file. Nothing is
    changed when *search_text* is empty.

    Args:
        folder_path (str, optional): Path to the folder containing caption files.
        caption_file_ext (str, optional): Extension of the caption files.
        search_text (str, optional): Text to search for in the caption files.
        replace_text (str, optional): Text to replace the search text with.
    """
    print('Running caption find/replace')

    caption_files = [
        f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)
    ]

    if not caption_files:
        msgbox(
            f'No files with extension {caption_file_ext} were found in {folder_path}...'
        )
        return

    if search_text == '':
        return

    for caption_file in caption_files:
        caption_path = os.path.join(folder_path, caption_file)
        # Undecodable bytes are dropped on read rather than aborting the run
        with open(caption_path, 'r', errors='ignore') as f:
            content = f.read()

        content = content.replace(search_text, replace_text)

        with open(caption_path, 'w') as f:
            f.write(content)


def color_aug_changed(color_aug):
    """
    Return the checkbox update to apply when "Color augmentation" changes.

    Args:
        color_aug (bool): New state of the color augmentation checkbox.

    Returns:
        A ``gr.Checkbox.update`` that unchecks and locks the checkbox when
        *color_aug* is set, or checks and unlocks it otherwise.
    """
    if color_aug:
        msgbox(
            'Disabling "Cache latent" because "Color augmentation" has been selected...'
        )
        return gr.Checkbox.update(value=False, interactive=False)

    return gr.Checkbox.update(value=True, interactive=True)


def save_inference_file(output_dir, v2, v_parameterization, output_name):
    """
    Copy the matching v2 inference yaml next to every saved model file.

    Each regular file in *output_dir* whose name starts with *output_name*
    gets a ``<file name>.yaml`` copy of the v2 inference configuration.
    Nothing is copied when *v2* is not set.

    Args:
        output_dir (str): Directory holding the saved model files.
        v2 (bool): Whether the model is a v2 model.
        v_parameterization (bool): Whether the model uses v_parameterization.
        output_name (str): Name prefix of the saved model files.
    """
    if not v2:
        return

    yaml_name = (
        'v2-inference-v.yaml' if v_parameterization else 'v2-inference.yaml'
    )

    for file in os.listdir(output_dir):
        if not file.startswith(output_name):
            continue
        # Directories (e.g. diffusers folders) get no yaml file
        if not os.path.isfile(os.path.join(output_dir, file)):
            continue

        file_name, _ = os.path.splitext(file)
        print(f'Saving {yaml_name} as {output_dir}/{file_name}.yaml')
        shutil.copy(
            f'./v2_inference/{yaml_name}',
            f'{output_dir}/{file_name}.yaml',
        )


def set_pretrained_model_name_or_path_input(
    model_list, pretrained_model_name_or_path, v2, v_parameterization
):
    """
    Synchronise the model path and v2 flags with the "Model Quick Pick" value.

    Args:
        model_list: Selected quick pick entry (a preset name or 'custom').
        pretrained_model_name_or_path: Current model name or path.
        v2: Current state of the v2 flag.
        v_parameterization: Current state of the v_parameterization flag.

    Returns:
        tuple: ``(model_list, pretrained_model_name_or_path, v2,
        v_parameterization)`` with the updated values.
    """
    selected = str(model_list)

    if selected in V2_BASE_MODELS:
        print('SD v2 model detected. Setting --v2 parameter')
        v2 = True
        v_parameterization = False
        pretrained_model_name_or_path = selected

    if selected in V_PARAMETERIZATION_MODELS:
        print(
            'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
        )
        v2 = True
        v_parameterization = True
        pretrained_model_name_or_path = selected

    if selected in V1_MODELS:
        v2 = False
        v_parameterization = False
        pretrained_model_name_or_path = selected

    if model_list == 'custom':
        # Switching to custom while a preset name is still in the field:
        # clear the field and reset the flags.
        if str(pretrained_model_name_or_path) in ALL_PRESET_MODELS:
            pretrained_model_name_or_path = ''
            v2 = False
            v_parameterization = False

    return model_list, pretrained_model_name_or_path, v2, v_parameterization


def set_v2_checkbox(model_list, v2, v_parameterization):
    """
    Force the v2 flags to the values required by a preset model.

    Args:
        model_list: Selected quick pick entry.
        v2: Current state of the v2 flag.
        v_parameterization: Current state of the v_parameterization flag.

    Returns:
        tuple: ``(v2, v_parameterization)``; unchanged when *model_list* is
        not one of the preset models.
    """
    selected = str(model_list)

    if selected in V2_BASE_MODELS:
        v2 = True
        v_parameterization = False

    if selected in V_PARAMETERIZATION_MODELS:
        v2 = True
        v_parameterization = True

    if selected in V1_MODELS:
        v2 = False
        v_parameterization = False

    return v2, v_parameterization


def set_model_list(
    model_list,
    pretrained_model_name_or_path,
    v2=False,
    v_parameterization=False,
):
    """
    Derive the quick pick entry matching a model name or path.

    Args:
        model_list: Current quick pick entry (replaced by the result).
        pretrained_model_name_or_path: Model name or path to look up.
        v2: Passed through unchanged.
        v_parameterization: Passed through unchanged.

    Returns:
        tuple: ``(model_list, v2, v_parameterization)`` where *model_list* is
        the preset name when the model is a preset and 'custom' otherwise.
    """
    if pretrained_model_name_or_path not in ALL_PRESET_MODELS:
        model_list = 'custom'
    else:
        model_list = pretrained_model_name_or_path

    return model_list, v2, v_parameterization


###
### Gradio common GUI section
###


def gradio_config():
    """
    Build the "Configuration file" accordion.

    Returns:
        tuple: ``(button_open_config, button_save_config,
        button_save_as_config, config_file_name, button_load_config)``.
    """
    with gr.Accordion('Configuration file', open=False):
        with gr.Row():
            button_open_config = gr.Button('Open 📂', elem_id='open_folder')
            button_save_config = gr.Button('Save 💾', elem_id='open_folder')
            button_save_as_config = gr.Button(
                'Save as... 💾', elem_id='open_folder'
            )
            config_file_name = gr.Textbox(
                label='',
                placeholder="type the configuration file path or use the 'Open' button above to select it...",
                interactive=True,
            )
            button_load_config = gr.Button('Load 💾', elem_id='open_folder')
            # Strip double quotes from the path whenever the field changes
            config_file_name.change(
                remove_doublequote,
                inputs=[config_file_name],
                outputs=[config_file_name],
            )
    return (
        button_open_config,
        button_save_config,
        button_save_as_config,
        config_file_name,
        button_load_config,
    )


def get_pretrained_model_name_or_path_file(
    model_list, pretrained_model_name_or_path
):
    """
    Let the user pick a model file through a file dialog.

    Args:
        model_list: Current quick pick entry.
        pretrained_model_name_or_path: Current model name or path, used to
            seed the dialog.
    """
    pretrained_model_name_or_path = get_any_file_path(
        pretrained_model_name_or_path
    )
    # NOTE(review): the result is discarded, as in the original code, so this
    # function returns None; confirm whether callers expect the new values.
    set_model_list(model_list, pretrained_model_name_or_path)


def gradio_source_model(save_model_as_choices=None):
    """
    Build the "Source model" tab.

    Args:
        save_model_as_choices (list, optional): Choices offered by the
            "Save trained model as" dropdown. Defaults to all known formats.

    Returns:
        tuple: ``(pretrained_model_name_or_path, v2, v_parameterization,
        save_model_as, model_list)``.
    """
    if save_model_as_choices is None:
        save_model_as_choices = [
            'same as source model',
            'ckpt',
            'diffusers',
            'diffusers_safetensors',
            'safetensors',
        ]

    with gr.Tab('Source model'):
        # Define the input elements
        with gr.Row():
            pretrained_model_name_or_path = gr.Textbox(
                label='Pretrained model name or path',
                placeholder='enter the path to custom model or name of pretrained model',
                value='runwayml/stable-diffusion-v1-5',
            )
            pretrained_model_name_or_path_file = gr.Button(
                document_symbol, elem_id='open_folder_small'
            )
            pretrained_model_name_or_path_file.click(
                get_any_file_path,
                inputs=pretrained_model_name_or_path,
                outputs=pretrained_model_name_or_path,
                show_progress=False,
            )
            pretrained_model_name_or_path_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            pretrained_model_name_or_path_folder.click(
                get_folder_path,
                inputs=pretrained_model_name_or_path,
                outputs=pretrained_model_name_or_path,
                show_progress=False,
            )
            model_list = gr.Dropdown(
                label='Model Quick Pick',
                choices=[
                    'custom',
                    'stabilityai/stable-diffusion-2-1-base',
                    'stabilityai/stable-diffusion-2-base',
                    'stabilityai/stable-diffusion-2-1',
                    'stabilityai/stable-diffusion-2',
                    'runwayml/stable-diffusion-v1-5',
                    'CompVis/stable-diffusion-v1-4',
                ],
                value='runwayml/stable-diffusion-v1-5',
            )
            save_model_as = gr.Dropdown(
                label='Save trained model as',
                choices=save_model_as_choices,
                value='safetensors',
            )

        with gr.Row():
            v2 = gr.Checkbox(label='v2', value=False)
            v_parameterization = gr.Checkbox(
                label='v_parameterization', value=False
            )
            # Preset models dictate both flags; re-apply them on every change
            v2.change(
                set_v2_checkbox,
                inputs=[model_list, v2, v_parameterization],
                outputs=[v2, v_parameterization],
                show_progress=False,
            )
            v_parameterization.change(
                set_v2_checkbox,
                inputs=[model_list, v2, v_parameterization],
                outputs=[v2, v_parameterization],
                show_progress=False,
            )
        model_list.change(
            set_pretrained_model_name_or_path_input,
            inputs=[
                model_list,
                pretrained_model_name_or_path,
                v2,
                v_parameterization,
            ],
            outputs=[
                model_list,
                pretrained_model_name_or_path,
                v2,
                v_parameterization,
            ],
            show_progress=False,
        )
        # Keep the quick pick entry in sync with the model path field
        pretrained_model_name_or_path.change(
            set_model_list,
            inputs=[
                model_list,
                pretrained_model_name_or_path,
                v2,
                v_parameterization,
            ],
            outputs=[
                model_list,
                v2,
                v_parameterization,
            ],
            show_progress=False,
        )
    return (
        pretrained_model_name_or_path,
        v2,
        v_parameterization,
        save_model_as,
        model_list,
    )


def gradio_training(
    learning_rate_value='1e-6',
    lr_scheduler_value='constant',
    lr_warmup_value='0',
):
    """
    Build the common training parameter rows.

    Args:
        learning_rate_value (str): Initial value of the learning rate field.
        lr_scheduler_value (str): Initial value of the LR scheduler dropdown.
        lr_warmup_value (str): Initial value of the LR warmup field.

    Returns:
        tuple: ``(learning_rate, lr_scheduler, lr_warmup, train_batch_size,
        epoch, save_every_n_epochs, mixed_precision, save_precision,
        num_cpu_threads_per_process, seed, caption_extension, cache_latents,
        optimizer, optimizer_args)``.
    """
    with gr.Row():
        train_batch_size = gr.Slider(
            minimum=1,
            maximum=64,
            label='Train batch size',
            value=1,
            step=1,
        )
        epoch = gr.Number(label='Epoch', value=1, precision=0)
        save_every_n_epochs = gr.Number(
            label='Save every N epochs', value=1, precision=0
        )
        caption_extension = gr.Textbox(
            label='Caption Extension',
            placeholder='(Optional) Extension for caption files. default: .caption',
        )
    with gr.Row():
        mixed_precision = gr.Dropdown(
            label='Mixed precision',
            choices=[
                'no',
                'fp16',
                'bf16',
            ],
            value='fp16',
        )
        save_precision = gr.Dropdown(
            label='Save precision',
            choices=[
                'float',
                'fp16',
                'bf16',
            ],
            value='fp16',
        )
        num_cpu_threads_per_process = gr.Slider(
            minimum=1,
            maximum=os.cpu_count(),
            step=1,
            label='Number of CPU threads per core',
            value=2,
        )
        seed = gr.Textbox(label='Seed', placeholder='(Optional) eg:1234')
        cache_latents = gr.Checkbox(label='Cache latent', value=True)
    with gr.Row():
        learning_rate = gr.Textbox(
            label='Learning rate', value=learning_rate_value
        )
        lr_scheduler = gr.Dropdown(
            label='LR Scheduler',
            choices=[
                'adafactor',
                'constant',
                'constant_with_warmup',
                'cosine',
                'cosine_with_restarts',
                'linear',
                'polynomial',
            ],
            value=lr_scheduler_value,
        )
        lr_warmup = gr.Textbox(
            label='LR warmup (% of steps)', value=lr_warmup_value
        )
        optimizer = gr.Dropdown(
            label='Optimizer',
            choices=[
                'AdamW',
                'AdamW8bit',
                'Adafactor',
                'DAdaptation',
                'Lion',
                'SGDNesterov',
                'SGDNesterov8bit',
            ],
            value='AdamW8bit',
            interactive=True,
        )
    with gr.Row():
        optimizer_args = gr.Textbox(
            label='Optimizer extra arguments',
            placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True',
        )
    return (
        learning_rate,
        lr_scheduler,
        lr_warmup,
        train_batch_size,
        epoch,
        save_every_n_epochs,
        mixed_precision,
        save_precision,
        num_cpu_threads_per_process,
        seed,
        caption_extension,
        cache_latents,
        optimizer,
        optimizer_args,
    )


def run_cmd_training(**kwargs):
    """
    Build the command line fragment for the common training parameters.

    Every option is optional: a missing, None or empty value leaves its flag
    out, except ``--optimizer_type`` which is always emitted (default AdamW).

    Args:
        **kwargs: learning_rate, lr_scheduler, lr_warmup_steps,
            train_batch_size, max_train_steps, save_every_n_epochs,
            mixed_precision, save_precision, seed, caption_extension,
            cache_latents, optimizer, optimizer_args.

    Returns:
        str: The options, each preceded by a single space.
    """
    # A missing value used to crash int(); treat it as "not requested"
    save_every_n_epochs = _to_int(kwargs.get('save_every_n_epochs'), 0)
    seed = kwargs.get('seed')
    optimizer_args = kwargs.get('optimizer_args')

    options = [
        f' --learning_rate="{kwargs.get("learning_rate", "")}"'
        if kwargs.get('learning_rate')
        else '',
        f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
        if kwargs.get('lr_scheduler')
        else '',
        f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
        if kwargs.get('lr_warmup_steps')
        else '',
        f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
        if kwargs.get('train_batch_size')
        else '',
        f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
        if kwargs.get('max_train_steps')
        else '',
        f' --save_every_n_epochs="{save_every_n_epochs}"'
        if save_every_n_epochs
        else '',
        f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
        if kwargs.get('mixed_precision')
        else '',
        f' --save_precision="{kwargs.get("save_precision", "")}"'
        if kwargs.get('save_precision')
        else '',
        # Compare against None and '' explicitly so a seed of 0 is kept
        f' --seed="{seed}"' if seed not in (None, '') else '',
        f' --caption_extension="{kwargs.get("caption_extension", "")}"'
        if kwargs.get('caption_extension')
        else '',
        ' --cache_latents' if kwargs.get('cache_latents') else '',
        f' --optimizer_type="{kwargs.get("optimizer") or "AdamW"}"',
        f' --optimizer_args {optimizer_args}' if optimizer_args else '',
    ]
    run_cmd = ''.join(options)
    return run_cmd


def gradio_advanced_training():
    """
    Build the advanced training parameter rows.

    Returns:
        tuple: ``(xformers, full_fp16, gradient_checkpointing,
        shuffle_caption, color_aug, flip_aug, clip_skip, mem_eff_attn,
        save_state, resume, max_token_length, max_train_epochs,
        max_data_loader_n_workers, keep_tokens,
        persistent_data_loader_workers, bucket_no_upscale, random_crop,
        bucket_reso_steps, caption_dropout_every_n_epochs,
        caption_dropout_rate, noise_offset, additional_parameters,
        vae_batch_size, min_snr_gamma)``.
    """
    with gr.Row():
        additional_parameters = gr.Textbox(
            label='Additional parameters',
            placeholder='(Optional) Use to provide additional parameters not handled by the GUI. Eg: --some_parameters "value"',
        )
    with gr.Row():
        keep_tokens = gr.Slider(
            label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
        )
        clip_skip = gr.Slider(
            label='Clip skip', value='1', minimum=1, maximum=12, step=1
        )
        max_token_length = gr.Dropdown(
            label='Max Token Length',
            choices=[
                '75',
                '150',
                '225',
            ],
            value='75',
        )
        full_fp16 = gr.Checkbox(
            label='Full fp16 training (experimental)', value=False
        )
    with gr.Row():
        gradient_checkpointing = gr.Checkbox(
            label='Gradient checkpointing', value=False
        )
        shuffle_caption = gr.Checkbox(label='Shuffle caption', value=False)
        persistent_data_loader_workers = gr.Checkbox(
            label='Persistent data loader', value=False
        )
        mem_eff_attn = gr.Checkbox(
            label='Memory efficient attention', value=False
        )
    with gr.Row():
        # The legacy use_8bit_adam checkbox is no longer created here; the
        # optimizer dropdown replaces it.
        xformers = gr.Checkbox(label='Use xformers', value=True)
        color_aug = gr.Checkbox(label='Color augmentation', value=False)
        flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
        min_snr_gamma = gr.Slider(
            label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1
        )
    with gr.Row():
        bucket_no_upscale = gr.Checkbox(
            label="Don't upscale bucket resolution", value=True
        )
        bucket_reso_steps = gr.Number(
            label='Bucket resolution steps', value=64
        )
        random_crop = gr.Checkbox(
            label='Random crop instead of center crop', value=False
        )
        noise_offset = gr.Textbox(
            label='Noise offset (0 - 1)', placeholder='(Optional) eg: 0.1'
        )

    with gr.Row():
        caption_dropout_every_n_epochs = gr.Number(
            label='Dropout caption every n epochs', value=0
        )
        caption_dropout_rate = gr.Slider(
            label='Rate of caption dropout', value=0, minimum=0, maximum=1
        )
        vae_batch_size = gr.Slider(
            label='VAE batch size', minimum=0, maximum=32, value=0, step=1
        )
    with gr.Row():
        save_state = gr.Checkbox(label='Save training state', value=False)
        resume = gr.Textbox(
            label='Resume from saved training state',
            placeholder='path to "last-state" state folder to resume from',
        )
        resume_button = gr.Button('📂', elem_id='open_folder_small')
        resume_button.click(
            get_folder_path,
            outputs=resume,
            show_progress=False,
        )
        max_train_epochs = gr.Textbox(
            label='Max train epoch',
            placeholder='(Optional) Override number of epoch',
        )
        max_data_loader_n_workers = gr.Textbox(
            label='Max num workers for DataLoader',
            placeholder='(Optional) Override number of workers. Default: 8',
            value="0",
        )
    return (
        xformers,
        full_fp16,
        gradient_checkpointing,
        shuffle_caption,
        color_aug,
        flip_aug,
        clip_skip,
        mem_eff_attn,
        save_state,
        resume,
        max_token_length,
        max_train_epochs,
        max_data_loader_n_workers,
        keep_tokens,
        persistent_data_loader_workers,
        bucket_no_upscale,
        random_crop,
        bucket_reso_steps,
        caption_dropout_every_n_epochs,
        caption_dropout_rate,
        noise_offset,
        additional_parameters,
        vae_batch_size,
        min_snr_gamma,
    )


def run_cmd_advanced_training(**kwargs):
    """
    Build the command line fragment for the advanced training parameters.

    Numeric options are only emitted above their neutral value (for example
    clip_skip > 1, keep_tokens > 0); missing, None or empty values count as
    the neutral value. Boolean options are emitted when truthy.

    Args:
        **kwargs: max_train_epochs, max_data_loader_n_workers,
            max_token_length, clip_skip, resume, keep_tokens,
            caption_dropout_every_n_epochs, caption_dropout_rate,
            vae_batch_size, bucket_reso_steps, min_snr_gamma, save_state,
            mem_eff_attn, color_aug, flip_aug, shuffle_caption,
            gradient_checkpointing, full_fp16, xformers,
            persistent_data_loader_workers, bucket_no_upscale, random_crop,
            noise_offset, additional_parameters.

    Returns:
        str: The options, each preceded by a single space, followed by a
        space and the additional parameters.
    """
    max_token_length = _to_int(kwargs.get('max_token_length'), 75)
    clip_skip = _to_int(kwargs.get('clip_skip'), 1)
    keep_tokens = _to_int(kwargs.get('keep_tokens'), 0)
    caption_dropout_every_n_epochs = _to_int(
        kwargs.get('caption_dropout_every_n_epochs'), 0
    )
    caption_dropout_rate = _to_float(kwargs.get('caption_dropout_rate'), 0.0)
    vae_batch_size = _to_int(kwargs.get('vae_batch_size'), 0)
    # One default (64, the GUI default) for both the test and the value
    bucket_reso_steps = _to_int(kwargs.get('bucket_reso_steps'), 64)
    min_snr_gamma = _to_int(kwargs.get('min_snr_gamma'), 0)
    noise_offset = kwargs.get('noise_offset')

    options = [
        f' --max_train_epochs="{kwargs.get("max_train_epochs", "")}"'
        if kwargs.get('max_train_epochs')
        else '',
        f' --max_data_loader_n_workers="{kwargs.get("max_data_loader_n_workers", "")}"'
        if kwargs.get('max_data_loader_n_workers')
        else '',
        f' --max_token_length={kwargs.get("max_token_length", "")}'
        if max_token_length > 75
        else '',
        f' --clip_skip={kwargs.get("clip_skip", "")}'
        if clip_skip > 1
        else '',
        f' --resume="{kwargs.get("resume", "")}"'
        if kwargs.get('resume')
        else '',
        f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
        if keep_tokens > 0
        else '',
        f' --caption_dropout_every_n_epochs="{caption_dropout_every_n_epochs}"'
        if caption_dropout_every_n_epochs > 0
        else '',
        # This slot used to repeat --caption_dropout_every_n_epochs, so the
        # dropout rate chosen in the GUI was never passed on.
        f' --caption_dropout_rate="{caption_dropout_rate}"'
        if caption_dropout_rate > 0
        else '',
        f' --vae_batch_size="{kwargs.get("vae_batch_size", 0)}"'
        if vae_batch_size > 0
        else '',
        f' --bucket_reso_steps={bucket_reso_steps}'
        if bucket_reso_steps >= 1
        else '',
        f' --min_snr_gamma={min_snr_gamma}' if min_snr_gamma >= 1 else '',
        ' --save_state' if kwargs.get('save_state') else '',
        ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
        ' --color_aug' if kwargs.get('color_aug') else '',
        ' --flip_aug' if kwargs.get('flip_aug') else '',
        ' --shuffle_caption' if kwargs.get('shuffle_caption') else '',
        ' --gradient_checkpointing'
        if kwargs.get('gradient_checkpointing')
        else '',
        ' --full_fp16' if kwargs.get('full_fp16') else '',
        ' --xformers' if kwargs.get('xformers') else '',
        ' --persistent_data_loader_workers'
        if kwargs.get('persistent_data_loader_workers')
        else '',
        ' --bucket_no_upscale' if kwargs.get('bucket_no_upscale') else '',
        ' --random_crop' if kwargs.get('random_crop') else '',
        f' --noise_offset={float(noise_offset)}'
        if noise_offset not in (None, '')
        else '',
        # Kept as "space + text" even when empty so callers that append to
        # the result see the same output as before.
        f' {kwargs.get("additional_parameters") or ""}',
    ]
    run_cmd = ''.join(options)
    return run_cmd