From 9d2e3f85a2e35fefa4946b9b3aa6f0e9bb4e55da Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 26 Feb 2023 19:49:22 -0500 Subject: [PATCH 1/6] Add tensorboard support --- README.md | 4 +++- dreambooth_gui.py | 17 ++++++++++++++- finetune_gui.py | 17 ++++++++++++++- library/common_gui.py | 47 ++++++++++++++++++++++++++++++++++++++++ lora_gui.py | 18 +++++++++++++-- textual_inversion_gui.py | 17 ++++++++++++++- 6 files changed, 114 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 2287dbc..7c847e6 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,9 @@ This will store your a backup file with your current locally installed pip packa ## Change History -* 2023/02/24 (v20.8.2): +* 2023/02/27 (v21.0.0): + - Add tensorboard start and stop support to the GUI +* 2023/02/26 (v20.8.2): - Fix issue https://github.com/bmaltais/kohya_ss/issues/231 - Change default for seed to random - Add support for --share argument to `kohya_gui.py` and `gui.ps1` diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 6c94f23..f41a1dd 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -25,6 +25,9 @@ from library.common_gui import ( gradio_config, gradio_source_model, set_legacy_8bitadam, + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -639,7 +642,19 @@ def dreambooth_tab( logging_dir_input=logging_dir, ) - button_run = gr.Button('Train model') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, diff --git a/finetune_gui.py b/finetune_gui.py index 4cf81f2..55278ee 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -19,6 +19,9 @@ from library.common_gui import ( color_aug_changed, run_cmd_training, set_legacy_8bitadam, + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.utilities import utilities_tab @@ -623,7 +626,19 @@ def finetune_tab(): outputs=[optimizer, use_8bit_adam], ) - button_run = gr.Button('Train model') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, diff --git a/library/common_gui.py b/library/common_gui.py index 25089af..941cc3b 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -3,12 +3,59 @@ import os import gradio as gr from easygui import msgbox import shutil +import subprocess +import time folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 +##### +# tensorboard section +##### + +tensorboard_proc = None # I know... bad but heh + +def start_tensorboard(logging_dir): + global tensorboard_proc + + if not os.listdir(logging_dir): + print("Error: log folder is empty") + return + + run_cmd = f'tensorboard.exe --logdir "{logging_dir}"' + + print(run_cmd) + if tensorboard_proc is not None: + print("Tensorboard is already running. Terminating existing process before starting new one...") + stop_tensorboard() + + # Start background process + print('Starting tensorboard...') + tensorboard_proc = subprocess.Popen(run_cmd) + + # Wait for some time to allow TensorBoard to start up + time.sleep(5) + + # Open the TensorBoard URL in the default browser + print('Opening tensorboard url in browser...') + import webbrowser + webbrowser.open('http://localhost:6006') + +def stop_tensorboard(): + print('Stopping tensorboard process...') + tensorboard_proc.kill() + print('...process stopped') + +def gradio_tensorboard(): + with gr.Row(): + button_start_tensorboard = gr.Button('Start tensorboard') + button_stop_tensorboard = gr.Button('Stop tensorboard') + + return(button_start_tensorboard, button_stop_tensorboard) + +##### def get_dir_and_file(file_path): dir_path, file_name = os.path.split(file_path) diff --git a/lora_gui.py b/lora_gui.py index 8b15a94..651d062 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -25,6 +25,9 @@ from library.common_gui import ( gradio_source_model, run_cmd_training, set_legacy_8bitadam, + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -41,7 +44,6 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 - def save_configuration( save_as, file_path, @@ -745,7 +747,19 @@ def lora_tab( gradio_resize_lora_tab() gradio_verify_lora_tab() - button_run = gr.Button('Train model') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 933c6f2..3bb8b93 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -25,6 +25,9 @@ from library.common_gui import ( gradio_config, gradio_source_model, set_legacy_8bitadam, + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -714,7 +717,19 @@ def ti_tab( logging_dir_input=logging_dir, ) - button_run = gr.Button('Train TI') + button_run = gr.Button('Train model', variant='primary') + + # Setup gradio tensorboard buttons + button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=logging_dir, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + ) settings_list = [ pretrained_model_name_or_path, From 6e664f11769f77debe991395330befcc2646f42e Mon Sep 17 00:00:00 2001 From: Ki-wimon <40882134+Ki-wimon@users.noreply.github.com> Date: Tue, 28 Feb 2023 01:16:23 +0800 Subject: [PATCH 2/6] support locon --- lora_gui.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/lora_gui.py b/lora_gui.py index 8b15a94..29523a7 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -10,6 +10,7 @@ import os import subprocess import pathlib import argparse +import shutil from library.common_gui import ( get_folder_path, remove_doublequote, @@ -40,7 +41,12 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 +locon_path = os.getcwd()+'\\locon\\' +def getlocon(): + os.system('git clone https://github.com/KohakuBlueleaf/LoCon.git') + os.system('ren '+locon_path[:-6]+'\\LoCon\\'+' locon_github-sourcecode') + shutil.copytree(locon_path[:-6]+'locon_github-sourcecode\\locon\\', locon_path) def save_configuration( save_as, @@ -102,7 +108,7 @@ def save_configuration( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, + optimizer_args,noise_offset, locon = 0 ): # Get list of function parameters and values parameters = list(locals().items()) @@ -208,7 +214,7 @@ def open_configuration( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, + optimizer_args,noise_offset, locon=0 ): # Get list of function parameters and values parameters = list(locals().items()) @@ -292,7 +298,7 @@ def train_model( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, + optimizer_args,noise_offset, locon ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -428,7 +434,12 @@ def train_model( run_cmd += f' --save_model_as={save_model_as}' if not float(prior_loss_weight) == 1.0: run_cmd += f' --prior_loss_weight={prior_loss_weight}' - run_cmd += f' --network_module=networks.lora' + if locon: + if not os.path.exists(locon_path): + getlocon() + run_cmd += ' --network_module=locon.locon_kohya' + else: + run_cmd += f' --network_module=networks.lora' if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0): if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0): @@ -676,6 +687,8 @@ def lora_tab( ) enable_bucket = gr.Checkbox(label='Enable buckets', value=True) with gr.Accordion('Advanced Configuration', open=False): + with gr.Row(): + locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA', value=False) with gr.Row(): no_token_padding = gr.Checkbox( label='No token padding', value=False @@ -805,7 +818,7 @@ def lora_tab( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, + optimizer_args,noise_offset,locon ] button_open_config.click( From c32a99dad5244a1cad03fdb0820d154cb8fdc143 Mon Sep 17 00:00:00 2001 From: Ki-wimon <40882134+Ki-wimon@users.noreply.github.com> Date: Tue, 28 Feb 2023 01:38:05 +0800 Subject: [PATCH 3/6] Update lora_gui.py --- lora_gui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lora_gui.py b/lora_gui.py index 29523a7..d2c867a 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -688,7 +688,7 @@ def lora_tab( enable_bucket = gr.Checkbox(label='Enable buckets', value=True) with gr.Accordion('Advanced Configuration', open=False): with gr.Row(): - locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA', value=False) + locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (may not be able to merge now)', value=False) with gr.Row(): no_token_padding = gr.Checkbox( label='No token padding', value=False From dfd155a8e15ffbec1fd12401172d0c9e80ceba55 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 28 Feb 2023 07:37:19 -0500 Subject: [PATCH 4/6] Undo LoCon commit --- lora_gui.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/lora_gui.py b/lora_gui.py index 99c6609..eb0b567 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -10,7 +10,6 @@ import os import subprocess import pathlib import argparse -import shutil from library.common_gui import ( get_folder_path, remove_doublequote, @@ -44,12 +43,6 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 -locon_path = os.getcwd()+'\\locon\\' - -def getlocon(): - os.system('git clone https://github.com/KohakuBlueleaf/LoCon.git') - os.system('ren '+locon_path[:-6]+'\\LoCon\\'+' locon_github-sourcecode') - shutil.copytree(locon_path[:-6]+'locon_github-sourcecode\\locon\\', locon_path) def save_configuration( save_as, @@ -111,7 +104,7 @@ def save_configuration( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, locon = 0 + optimizer_args,noise_offset, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -217,7 +210,7 @@ def open_configuration( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, locon=0 + optimizer_args,noise_offset, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -301,7 +294,7 @@ def train_model( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, locon + optimizer_args,noise_offset, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -437,12 +430,7 @@ def train_model( run_cmd += f' --save_model_as={save_model_as}' if not float(prior_loss_weight) == 1.0: run_cmd += f' --prior_loss_weight={prior_loss_weight}' - if locon: - if not os.path.exists(locon_path): - getlocon() - run_cmd += ' --network_module=locon.locon_kohya' - else: - run_cmd += f' --network_module=networks.lora' + run_cmd += f' --network_module=networks.lora' if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0): if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0): @@ -690,8 +678,6 @@ def lora_tab( ) enable_bucket = gr.Checkbox(label='Enable buckets', value=True) with gr.Accordion('Advanced Configuration', open=False): - with gr.Row(): - locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (may not be able to merge now)', value=False) with gr.Row(): no_token_padding = gr.Checkbox( label='No token padding', value=False @@ -833,7 +819,7 @@ def lora_tab( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset,locon + optimizer_args,noise_offset, ] button_open_config.click( @@ -922,4 +908,4 @@ if __name__ == '__main__': args = parser.parse_args() - UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port) + UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port) \ No newline at end of file From 1e3055c895982bb2d48adf7daa1574e9362ce0f5 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Wed, 1 Mar 2023 13:14:47 -0500 Subject: [PATCH 5/6] Update tensorboard --- dreambooth_gui.py | 2 ++ finetune_gui.py | 2 ++ library/common_gui.py | 47 -------------------------------------- library/tensorboard_gui.py | 46 +++++++++++++++++++++++++++++++++++++ lora_gui.py | 8 ++++--- textual_inversion_gui.py | 2 ++ tools/rename_depth_mask.py | 21 +++++++++++++++++ 7 files changed, 78 insertions(+), 50 deletions(-) create mode 100644 library/tensorboard_gui.py create mode 100644 tools/rename_depth_mask.py diff --git a/dreambooth_gui.py b/dreambooth_gui.py index f41a1dd..1e76d89 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -25,6 +25,8 @@ from library.common_gui import ( gradio_config, gradio_source_model, set_legacy_8bitadam, +) +from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, diff --git a/finetune_gui.py b/finetune_gui.py index 55278ee..7777a7f 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -19,6 +19,8 @@ from library.common_gui import ( color_aug_changed, run_cmd_training, set_legacy_8bitadam, +) +from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, diff --git a/library/common_gui.py b/library/common_gui.py index 941cc3b..25089af 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -3,59 +3,12 @@ import os import gradio as gr from easygui import msgbox import shutil -import subprocess -import time folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 -##### -# tensorboard section -##### - -tensorboard_proc = None # I know... bad but heh - -def start_tensorboard(logging_dir): - global tensorboard_proc - - if not os.listdir(logging_dir): - print("Error: log folder is empty") - return - - run_cmd = f'tensorboard.exe --logdir "{logging_dir}"' - - print(run_cmd) - if tensorboard_proc is not None: - print("Tensorboard is already running. Terminating existing process before starting new one...") - stop_tensorboard() - - # Start background process - print('Starting tensorboard...') - tensorboard_proc = subprocess.Popen(run_cmd) - - # Wait for some time to allow TensorBoard to start up - time.sleep(5) - - # Open the TensorBoard URL in the default browser - print('Opening tensorboard url in browser...') - import webbrowser - webbrowser.open('http://localhost:6006') - -def stop_tensorboard(): - print('Stopping tensorboard process...') - tensorboard_proc.kill() - print('...process stopped') - -def gradio_tensorboard(): - with gr.Row(): - button_start_tensorboard = gr.Button('Start tensorboard') - button_stop_tensorboard = gr.Button('Stop tensorboard') - - return(button_start_tensorboard, button_stop_tensorboard) - -##### def get_dir_and_file(file_path): dir_path, file_name = os.path.split(file_path) diff --git a/library/tensorboard_gui.py b/library/tensorboard_gui.py new file mode 100644 index 0000000..fa90a1c --- /dev/null +++ b/library/tensorboard_gui.py @@ -0,0 +1,46 @@ +import os +import gradio as gr +from easygui import msgbox +import subprocess +import time + +tensorboard_proc = None # I know... bad but heh + +def start_tensorboard(logging_dir): + global tensorboard_proc + + if not os.listdir(logging_dir): + print("Error: log folder is empty") + msgbox(msg="Error: log folder is empty") + return + + run_cmd = f'tensorboard.exe --logdir "{logging_dir}"' + + print(run_cmd) + if tensorboard_proc is not None: + print("Tensorboard is already running. Terminating existing process before starting new one...") + stop_tensorboard() + + # Start background process + print('Starting tensorboard...') + tensorboard_proc = subprocess.Popen(run_cmd) + + # Wait for some time to allow TensorBoard to start up + time.sleep(5) + + # Open the TensorBoard URL in the default browser + print('Opening tensorboard url in browser...') + import webbrowser + webbrowser.open('http://localhost:6006') + +def stop_tensorboard(): + print('Stopping tensorboard process...') + tensorboard_proc.kill() + print('...process stopped') + +def gradio_tensorboard(): + with gr.Row(): + button_start_tensorboard = gr.Button('Start tensorboard') + button_stop_tensorboard = gr.Button('Stop tensorboard') + + return(button_start_tensorboard, button_stop_tensorboard) diff --git a/lora_gui.py b/lora_gui.py index eb0b567..b1f3f34 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -25,13 +25,15 @@ from library.common_gui import ( gradio_source_model, run_cmd_training, set_legacy_8bitadam, - gradio_tensorboard, - start_tensorboard, - stop_tensorboard, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) +from library.tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, +) from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab from library.merge_lora_gui import gradio_merge_lora_tab diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 3bb8b93..aaa4df3 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -25,6 +25,8 @@ from library.common_gui import ( gradio_config, gradio_source_model, set_legacy_8bitadam, +) +from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, diff --git a/tools/rename_depth_mask.py b/tools/rename_depth_mask.py new file mode 100644 index 0000000..97efdea --- /dev/null +++ b/tools/rename_depth_mask.py @@ -0,0 +1,21 @@ +import os +import argparse + +# Define the command line arguments +parser = argparse.ArgumentParser(description='Rename files in a folder') +parser.add_argument('folder', metavar='folder', type=str, help='the folder containing the files to rename') + +# Parse the arguments +args = parser.parse_args() + +# Get the list of files in the folder +files = os.listdir(args.folder) + +# Loop through each file in the folder +for file in files: + # Check if the file has the expected format + if file.endswith('-0000.png'): + # Get the new file name + new_file_name = file[:-9] + '.mask' + # Rename the file + os.rename(os.path.join(args.folder, file), os.path.join(args.folder, new_file_name)) From 7f0e5683c6b4e784f4a20647dbfa45aaa579467c Mon Sep 17 00:00:00 2001 From: bmaltais Date: Wed, 1 Mar 2023 19:02:04 -0500 Subject: [PATCH 6/6] v21.0.1 --- README.md | 3 +++ dreambooth_gui.py | 3 +++ finetune_gui.py | 18 ++++++++++-------- library/common_gui.py | 9 ++++++++- lora_gui.py | 3 +++ textual_inversion_gui.py | 3 +++ 6 files changed, 30 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 7c847e6..f7a97ae 100644 --- a/README.md +++ b/README.md @@ -163,6 +163,9 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/03/01 (v21.0.1): + - Add warning to tensorboard start if the log information is missing + - Fix issue with 8bitadam on older config file load * 2023/02/27 (v21.0.0): - Add tensorboard start and stop support to the GUI * 2023/02/26 (v20.8.2): diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 1e76d89..bf48991 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -25,6 +25,7 @@ from library.common_gui import ( gradio_config, gradio_source_model, set_legacy_8bitadam, + update_optimizer, ) from library.tensorboard_gui import ( gradio_tensorboard, @@ -208,6 +209,8 @@ def open_configuration( with open(file_path, 'r') as f: my_data_db = json.load(f) print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_optimizer(my_data) else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action my_data_db = {} diff --git a/finetune_gui.py b/finetune_gui.py index 7777a7f..d1f6393 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -19,6 +19,7 @@ from library.common_gui import ( color_aug_changed, run_cmd_training, set_legacy_8bitadam, + update_optimizer, ) from library.tensorboard_gui import ( gradio_tensorboard, @@ -203,21 +204,22 @@ def open_config_file( original_file_path = file_path file_path = get_file_path(file_path) - if file_path != '' and file_path != None: - print(f'Loading config file {file_path}') + if not file_path == '' and not file_path == None: # load variables from JSON file with open(file_path, 'r') as f: - my_data_ft = json.load(f) + my_data_db = json.load(f) + print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_optimizer(my_data) else: - file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action - my_data_ft = {} + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data_db = {} values = [file_path] for key, value in parameters: - # Set the value in the dictionary to the corresponding value in `my_data_ft`, or the default value if not found + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found if not key in ['file_path']: - values.append(my_data_ft.get(key, value)) - # print(values) + values.append(my_data_db.get(key, value)) return tuple(values) diff --git a/library/common_gui.py b/library/common_gui.py index 25089af..69240c1 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -9,6 +9,12 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 +def update_optimizer(my_data): + if my_data.get('use_8bit_adam', False): + my_data['optimizer'] = 'AdamW8bit' + my_data['use_8bit_adam'] = False + return my_data + def get_dir_and_file(file_path): dir_path, file_name = os.path.split(file_path) @@ -604,7 +610,8 @@ def gradio_advanced_training(): label='Memory efficient attention', value=False ) with gr.Row(): - use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) + # This use_8bit_adam element should be removed in a future release as it is no longer used + use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=False, visible=False) xformers = gr.Checkbox(label='Use xformers', value=True) color_aug = gr.Checkbox(label='Color augmentation', value=False) flip_aug = gr.Checkbox(label='Flip augmentation', value=False) diff --git a/lora_gui.py b/lora_gui.py index b1f3f34..ac7d69e 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -25,6 +25,7 @@ from library.common_gui import ( gradio_source_model, run_cmd_training, set_legacy_8bitadam, + update_optimizer, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -225,6 +226,8 @@ def open_configuration( with open(file_path, 'r') as f: my_data = json.load(f) print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_optimizer(my_data) else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action my_data = {} diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index aaa4df3..cc93beb 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -25,6 +25,7 @@ from library.common_gui import ( gradio_config, gradio_source_model, set_legacy_8bitadam, + update_optimizer, ) from library.tensorboard_gui import ( gradio_tensorboard, @@ -218,6 +219,8 @@ def open_configuration( with open(file_path, 'r') as f: my_data_db = json.load(f) print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_optimizer(my_data) else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action my_data_db = {}