From 9d2e3f85a2e35fefa4946b9b3aa6f0e9bb4e55da Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 26 Feb 2023 19:49:22 -0500 Subject: [PATCH] Add tensorboard support --- README.md | 4 +++- dreambooth_gui.py | 17 ++++++++++++++- finetune_gui.py | 17 ++++++++++++++- library/common_gui.py | 47 ++++++++++++++++++++++++++++++++++++++++ lora_gui.py | 18 +++++++++++++-- textual_inversion_gui.py | 17 ++++++++++++++- 6 files changed, 114 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 2287dbc..7c847e6 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,9 @@ This will store your a backup file with your current locally installed pip packa ## Change History -* 2023/02/24 (v20.8.2): +* 2023/02/27 (v21.0.0): + - Add tensorboard start and stop support to the GUI +* 2023/02/26 (v20.8.2): - Fix issue https://github.com/bmaltais/kohya_ss/issues/231 - Change default for seed to random - Add support for --share argument to `kohya_gui.py` and `gui.ps1` diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 6c94f23..f41a1dd 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -25,6 +25,9 @@ from library.common_gui import ( gradio_config, gradio_source_model, set_legacy_8bitadam, + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -639,7 +642,19 @@ def dreambooth_tab( logging_dir_input=logging_dir, ) - button_run = gr.Button('Train model') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, diff --git a/finetune_gui.py b/finetune_gui.py index 4cf81f2..55278ee 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -19,6 +19,9 @@ from library.common_gui import ( color_aug_changed, run_cmd_training, set_legacy_8bitadam, + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.utilities import utilities_tab @@ -623,7 +626,19 @@ def finetune_tab(): outputs=[optimizer, use_8bit_adam], ) - button_run = gr.Button('Train model') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, diff --git a/library/common_gui.py b/library/common_gui.py index 25089af..941cc3b 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -3,12 +3,59 @@ import os import gradio as gr from easygui import msgbox import shutil +import subprocess +import time folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 +##### +# tensorboard section +##### + +tensorboard_proc = None # I know... bad but heh + +def start_tensorboard(logging_dir): + global tensorboard_proc + + if not os.listdir(logging_dir): + print("Error: log folder is empty") + return + + run_cmd = f'tensorboard.exe --logdir "{logging_dir}"' + + print(run_cmd) + if tensorboard_proc is not None: + print("Tensorboard is already running. Terminating existing process before starting new one...") + stop_tensorboard() + + # Start background process + print('Starting tensorboard...') + tensorboard_proc = subprocess.Popen(run_cmd) + + # Wait for some time to allow TensorBoard to start up + time.sleep(5) + + # Open the TensorBoard URL in the default browser + print('Opening tensorboard url in browser...') + import webbrowser + webbrowser.open('http://localhost:6006') + +def stop_tensorboard(): + print('Stopping tensorboard process...') + tensorboard_proc.kill() + print('...process stopped') + +def gradio_tensorboard(): + with gr.Row(): + button_start_tensorboard = gr.Button('Start tensorboard') + button_stop_tensorboard = gr.Button('Stop tensorboard') + + return(button_start_tensorboard, button_stop_tensorboard) + +##### def get_dir_and_file(file_path): dir_path, file_name = os.path.split(file_path) diff --git a/lora_gui.py b/lora_gui.py index 8b15a94..651d062 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -25,6 +25,9 @@ from library.common_gui import ( gradio_source_model, run_cmd_training, set_legacy_8bitadam, + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -41,7 +44,6 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 - def save_configuration( save_as, file_path, @@ -745,7 +747,19 @@ def lora_tab( gradio_resize_lora_tab() gradio_verify_lora_tab() - button_run = gr.Button('Train model') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 933c6f2..3bb8b93 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -25,6 +25,9 @@ from library.common_gui import ( gradio_config, gradio_source_model, set_legacy_8bitadam, + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -714,7 +717,19 @@ def ti_tab( logging_dir_input=logging_dir, ) - button_run = gr.Button('Train TI') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path,