diff --git a/README.md b/README.md index 2287dbc..f7a97ae 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,12 @@ This will store your a backup file with your current locally installed pip packa ## Change History -* 2023/02/24 (v20.8.2): +* 2023/03/01 (v21.0.1): + - Add warning to tensorboard start if the log information is missing + - Fix issue with 8bitadam on older config file load +* 2023/02/27 (v21.0.0): + - Add tensorboard start and stop support to the GUI +* 2023/02/26 (v20.8.2): - Fix issue https://github.com/bmaltais/kohya_ss/issues/231 - Change default for seed to random - Add support for --share argument to `kohya_gui.py` and `gui.ps1` diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 6c94f23..bf48991 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -25,6 +25,12 @@ from library.common_gui import ( gradio_config, gradio_source_model, set_legacy_8bitadam, + update_optimizer, +) +from library.tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -203,6 +209,8 @@ def open_configuration( with open(file_path, 'r') as f: my_data_db = json.load(f) print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_optimizer(my_data) else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action my_data_db = {} @@ -639,7 +647,19 @@ def dreambooth_tab( logging_dir_input=logging_dir, ) - button_run = gr.Button('Train model') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, diff --git a/finetune_gui.py b/finetune_gui.py index 4cf81f2..d1f6393 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -19,6 +19,12 @@ from library.common_gui import ( color_aug_changed, run_cmd_training, set_legacy_8bitadam, + update_optimizer, +) +from library.tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.utilities import utilities_tab @@ -198,21 +204,22 @@ def open_config_file( original_file_path = file_path file_path = get_file_path(file_path) - if file_path != '' and file_path != None: - print(f'Loading config file {file_path}') + if not file_path == '' and not file_path == None: # load variables from JSON file with open(file_path, 'r') as f: - my_data_ft = json.load(f) + my_data_db = json.load(f) + print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_optimizer(my_data) else: - file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action - my_data_ft = {} + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data_db = {} values = [file_path] for key, value in parameters: - # Set the value in the dictionary to the corresponding value in `my_data_ft`, or the default value if not found + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found if not key in ['file_path']: - values.append(my_data_ft.get(key, value)) - # print(values) + values.append(my_data_db.get(key, value)) return tuple(values) @@ -623,7 +630,19 @@ def finetune_tab(): outputs=[optimizer, use_8bit_adam], ) - button_run = gr.Button('Train model') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, diff --git a/library/common_gui.py b/library/common_gui.py index 25089af..69240c1 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -9,6 +9,12 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 +def update_optimizer(my_data): + if my_data.get('use_8bit_adam', False): + my_data['optimizer'] = 'AdamW8bit' + my_data['use_8bit_adam'] = False + return my_data + def get_dir_and_file(file_path): dir_path, file_name = os.path.split(file_path) @@ -604,7 +610,8 @@ def gradio_advanced_training(): label='Memory efficient attention', value=False ) with gr.Row(): - use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) + # This use_8bit_adam element should be removed in a future release as it is no longer used + use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=False, visible=False) xformers = gr.Checkbox(label='Use xformers', value=True) color_aug = gr.Checkbox(label='Color augmentation', value=False) flip_aug = gr.Checkbox(label='Flip augmentation', value=False) diff --git a/library/tensorboard_gui.py b/library/tensorboard_gui.py new file mode 100644 index 0000000..fa90a1c --- /dev/null +++ b/library/tensorboard_gui.py @@ -0,0 +1,46 @@ +import os +import gradio as gr +from easygui import msgbox +import subprocess +import time + +tensorboard_proc = None # I know... bad but heh + +def start_tensorboard(logging_dir): + global tensorboard_proc + + if not os.listdir(logging_dir): + print("Error: log folder is empty") + msgbox(msg="Error: log folder is empty") + return + + run_cmd = f'tensorboard.exe --logdir "{logging_dir}"' + + print(run_cmd) + if tensorboard_proc is not None: + print("Tensorboard is already running. Terminating existing process before starting new one...") + stop_tensorboard() + + # Start background process + print('Starting tensorboard...') + tensorboard_proc = subprocess.Popen(run_cmd) + + # Wait for some time to allow TensorBoard to start up + time.sleep(5) + + # Open the TensorBoard URL in the default browser + print('Opening tensorboard url in browser...') + import webbrowser + webbrowser.open('http://localhost:6006') + +def stop_tensorboard(): + print('Stopping tensorboard process...') + tensorboard_proc.kill() + print('...process stopped') + +def gradio_tensorboard(): + with gr.Row(): + button_start_tensorboard = gr.Button('Start tensorboard') + button_stop_tensorboard = gr.Button('Stop tensorboard') + + return(button_start_tensorboard, button_stop_tensorboard) diff --git a/lora_gui.py b/lora_gui.py index 8b15a94..ac7d69e 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -25,10 +25,16 @@ from library.common_gui import ( gradio_source_model, run_cmd_training, set_legacy_8bitadam, + update_optimizer, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) +from library.tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, +) from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab from library.merge_lora_gui import gradio_merge_lora_tab @@ -41,7 +47,6 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 - def save_configuration( save_as, file_path, @@ -221,6 +226,8 @@ def open_configuration( with open(file_path, 'r') as f: my_data = json.load(f) print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_optimizer(my_data) else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action my_data = {} @@ -745,7 +752,19 @@ def lora_tab( gradio_resize_lora_tab() gradio_verify_lora_tab() - button_run = gr.Button('Train model') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, @@ -894,4 +913,4 @@ if __name__ == '__main__': args = parser.parse_args() - UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port) + UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port) \ No newline at end of file diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 933c6f2..cc93beb 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -25,6 +25,12 @@ from library.common_gui import ( gradio_config, gradio_source_model, set_legacy_8bitadam, + update_optimizer, +) +from library.tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -213,6 +219,8 @@ def open_configuration( with open(file_path, 'r') as f: my_data_db = json.load(f) print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_optimizer(my_data) else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action my_data_db = {} @@ -714,7 +722,19 @@ def ti_tab( logging_dir_input=logging_dir, ) - button_run = gr.Button('Train TI') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, diff --git a/tools/rename_depth_mask.py b/tools/rename_depth_mask.py new file mode 100644 index 0000000..97efdea --- /dev/null +++ b/tools/rename_depth_mask.py @@ -0,0 +1,21 @@ +import os +import argparse + +# Define the command line arguments +parser = argparse.ArgumentParser(description='Rename files in a folder') +parser.add_argument('folder', metavar='folder', type=str, help='the folder containing the files to rename') + +# Parse the arguments +args = parser.parse_args() + +# Get the list of files in the folder +files = os.listdir(args.folder) + +# Loop through each file in the folder +for file in files: + # Check if the file has the expected format + if file.endswith('-0000.png'): + # Get the new file name + new_file_name = file[:-9] + '.mask' + # Rename the file + os.rename(os.path.join(args.folder, file), os.path.join(args.folder, new_file_name))