diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 2e9cfdb..b843df3 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -435,40 +435,6 @@ def train_model( save_inference_file(output_dir, v2, v_parameterization, output_name) -def UI(username, password): - css = '' - - if os.path.exists('./style.css'): - with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - print('Load CSS...') - css += file.read() + '\n' - - interface = gr.Blocks(css=css) - - with interface: - with gr.Tab('Dreambooth'): - ( - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - logging_dir_input, - ) = dreambooth_tab() - with gr.Tab('Utilities'): - utilities_tab( - train_data_dir_input=train_data_dir_input, - reg_data_dir_input=reg_data_dir_input, - output_dir_input=output_dir_input, - logging_dir_input=logging_dir_input, - enable_copy_info_button=True, - ) - - # Show the interface - if not username == '': - interface.launch(auth=(username, password)) - else: - interface.launch() - - def dreambooth_tab( train_data_dir=gr.Textbox(), reg_data_dir=gr.Textbox(), @@ -735,6 +701,44 @@ def dreambooth_tab( ) +def UI(**kwargs): + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css) + + with interface: + with gr.Tab('Dreambooth'): + ( + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + logging_dir_input, + ) = dreambooth_tab() + with gr.Tab('Utilities'): + utilities_tab( + train_data_dir_input=train_data_dir_input, + reg_data_dir_input=reg_data_dir_input, + output_dir_input=output_dir_input, + logging_dir_input=logging_dir_input, + enable_copy_info_button=True, + ) + + # Show the interface + launch_kwargs={} + if not kwargs.get('username', None) == '': + launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None)) + if kwargs.get('server_port', 0) > 0: + launch_kwargs["server_port"] = kwargs.get('server_port', 0) + if kwargs.get('inbrowser', False): + launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False) + print(launch_kwargs) + interface.launch(**launch_kwargs) + if __name__ == '__main__': # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() @@ -744,7 +748,11 @@ if __name__ == '__main__': parser.add_argument( '--password', type=str, default='', help='Password for authentication' ) + parser.add_argument( + '--server_port', type=int, default=0, help='Port to run the server listener on' + ) + parser.add_argument("--inbrowser", action="store_true", help="Open in browser") args = parser.parse_args() - UI(username=args.username, password=args.password) + UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port) diff --git a/finetune_gui.py b/finetune_gui.py index f5aad9d..4c2ccec 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -431,30 +431,6 @@ def remove_doublequote(file_path): return file_path -def UI(username, password): - - css = '' - - if os.path.exists('./style.css'): - with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - print('Load CSS...') - css += file.read() + '\n' - - interface = gr.Blocks(css=css) - - with interface: - with gr.Tab('Finetune'): - finetune_tab() - with gr.Tab('Utilities'): - utilities_tab(enable_dreambooth_tab=False) - - # Show the interface - if not username == '': - interface.launch(auth=(username, password)) - else: - interface.launch() - - def finetune_tab(): dummy_ft_true = gr.Label(value=True, visible=False) dummy_ft_false = gr.Label(value=False, visible=False) @@ -708,6 +684,35 @@ def finetune_tab(): ) +def UI(**kwargs): + + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css) + + with interface: + with gr.Tab('Finetune'): + finetune_tab() + with gr.Tab('Utilities'): + utilities_tab(enable_dreambooth_tab=False) + + # Show the interface + launch_kwargs={} + if not kwargs.get('username', None) == '': + launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None)) + if kwargs.get('server_port', 0) > 0: + launch_kwargs["server_port"] = kwargs.get('server_port', 0) + if kwargs.get('inbrowser', False): + launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False) + print(launch_kwargs) + interface.launch(**launch_kwargs) + + if __name__ == '__main__': # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() @@ -717,7 +722,11 @@ if __name__ == '__main__': parser.add_argument( '--password', type=str, default='', help='Password for authentication' ) + parser.add_argument( + '--server_port', type=int, default=0, help='Port to run the server listener on' + ) + parser.add_argument("--inbrowser", action="store_true", help="Open in browser") args = parser.parse_args() - UI(username=args.username, password=args.password) + UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port) diff --git a/gui.bat b/gui.bat index fbf5101..a7e8d08 100644 --- a/gui.bat +++ b/gui.bat @@ -1,10 +1,6 @@ @echo off -set VENV_DIR=.\venv -set PYTHON=python - -call %VENV_DIR%\Scripts\activate.bat - -%PYTHON% kohya_gui.py +call venv\Scripts\activate.bat +python.exe kohya_gui.py %* pause \ No newline at end of file diff --git a/gui.ps1 b/gui.ps1 index 4f799a1..fde5028 100644 --- a/gui.ps1 +++ b/gui.ps1 @@ -1,2 +1,2 @@ .\venv\Scripts\activate -python.exe kohya_gui.py \ No newline at end of file +python.exe kohya_gui.py $args \ No newline at end of file diff --git a/kohya_gui.py b/kohya_gui.py index b44c652..2c999a6 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -10,8 +10,7 @@ from library.merge_lora_gui import gradio_merge_lora_tab from lora_gui import lora_tab -def UI(username, password, inbrowser, server_port): - +def UI(**kwargs): css = '' if os.path.exists('./style.css'): @@ -47,13 +46,18 @@ def UI(username, password, inbrowser, server_port): gradio_merge_lora_tab() # Show the interface - kwargs = {} - if username: - kwargs["auth"] = (username, password) + launch_kwargs = {} + username = kwargs.get('username') + password = kwargs.get('password') + server_port = kwargs.get('server_port', 0) + inbrowser = kwargs.get('inbrowser', False) + if username and password: + launch_kwargs["auth"] = (username, password) if server_port > 0: - kwargs["server_port"] = server_port - kwargs["inbrowser"] = inbrowser - interface.launch(**kwargs) + launch_kwargs["server_port"] = server_port + if inbrowser: + launch_kwargs["inbrowser"] = inbrowser + interface.launch(**launch_kwargs) if __name__ == '__main__': # torch.cuda.set_per_process_memory_fraction(0.48) diff --git a/library/common_gui.py b/library/common_gui.py index a78532d..d52e07e 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -568,9 +568,11 @@ def gradio_advanced_training(): label="Dropout caption every n epochs", value=0 ) - caption_dropout_rate = gr.Number( + caption_dropout_rate = gr.Slider( label="Rate of caption dropout", - value=0 + value=0, + minimum=0, + maximum=1 ) with gr.Row(): save_state = gr.Checkbox(label='Save training state', value=False) diff --git a/library/utilities.py b/library/utilities.py index 523c2c2..a9fa5f4 100644 --- a/library/utilities.py +++ b/library/utilities.py @@ -36,7 +36,7 @@ def utilities_tab( ) -def UI(username, password): +def UI(**kwargs): css = '' if os.path.exists('./style.css'): @@ -50,11 +50,16 @@ def UI(username, password): utilities_tab() # Show the interface - if not username == '': - interface.launch(auth=(username, password)) - else: - interface.launch() - + launch_kwargs={} + if not kwargs.get('username', None) == '': + launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None)) + if kwargs.get('server_port', 0) > 0: + launch_kwargs["server_port"] = kwargs.get('server_port', 0) + if kwargs.get('inbrowser', False): + launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False) + print(launch_kwargs) + interface.launch(**launch_kwargs) + if __name__ == '__main__': # torch.cuda.set_per_process_memory_fraction(0.48) @@ -65,7 +70,11 @@ if __name__ == '__main__': parser.add_argument( '--password', type=str, default='', help='Password for authentication' ) + parser.add_argument( + '--server_port', type=int, default=0, help='Port to run the server listener on' + ) + parser.add_argument("--inbrowser", action="store_true", help="Open in browser") args = parser.parse_args() - UI(username=args.username, password=args.password) + UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port) diff --git a/lora_gui.py b/lora_gui.py index d48fb5a..a2ec7b7 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -495,40 +495,6 @@ def train_model( save_inference_file(output_dir, v2, v_parameterization, output_name) -def UI(username, password): - css = '' - - if os.path.exists('./style.css'): - with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - print('Load CSS...') - css += file.read() + '\n' - - interface = gr.Blocks(css=css) - - with interface: - with gr.Tab('LoRA'): - ( - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - logging_dir_input, - ) = lora_tab() - with gr.Tab('Utilities'): - utilities_tab( - train_data_dir_input=train_data_dir_input, - reg_data_dir_input=reg_data_dir_input, - output_dir_input=output_dir_input, - logging_dir_input=logging_dir_input, - enable_copy_info_button=True, - ) - - # Show the interface - if not username == '': - interface.launch(auth=(username, password)) - else: - interface.launch() - - def lora_tab( train_data_dir_input=gr.Textbox(), reg_data_dir_input=gr.Textbox(), @@ -644,7 +610,7 @@ def lora_tab( caption_extension, cache_latents, ) = gradio_training( - learning_rate_value='1e-5', + learning_rate_value='0.0001', lr_scheduler_value='cosine', lr_warmup_value='10', ) @@ -656,7 +622,7 @@ def lora_tab( ) unet_lr = gr.Textbox( label='Unet learning rate', - value='1e-3', + value='0.0001', placeholder='Optional', ) network_dim = gr.Slider( @@ -845,6 +811,45 @@ def lora_tab( ) +def UI(**kwargs): + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css) + + with interface: + with gr.Tab('LoRA'): + ( + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + logging_dir_input, + ) = lora_tab() + with gr.Tab('Utilities'): + utilities_tab( + train_data_dir_input=train_data_dir_input, + reg_data_dir_input=reg_data_dir_input, + output_dir_input=output_dir_input, + logging_dir_input=logging_dir_input, + enable_copy_info_button=True, + ) + + # Show the interface + launch_kwargs={} + if not kwargs.get('username', None) == '': + launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None)) + if kwargs.get('server_port', 0) > 0: + launch_kwargs["server_port"] = kwargs.get('server_port', 0) + if kwargs.get('inbrowser', False): + launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False) + print(launch_kwargs) + interface.launch(**launch_kwargs) + + if __name__ == '__main__': # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() @@ -854,7 +859,11 @@ if __name__ == '__main__': parser.add_argument( '--password', type=str, default='', help='Password for authentication' ) + parser.add_argument( + '--server_port', type=int, default=0, help='Port to run the server listener on' + ) + parser.add_argument("--inbrowser", action="store_true", help="Open in browser") args = parser.parse_args() - UI(username=args.username, password=args.password) + UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port) diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index d7b86ef..336be06 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -481,40 +481,6 @@ def train_model( save_inference_file(output_dir, v2, v_parameterization, output_name) -def UI(username, password): - css = '' - - if os.path.exists('./style.css'): - with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - print('Load CSS...') - css += file.read() + '\n' - - interface = gr.Blocks(css=css) - - with interface: - with gr.Tab('Dreambooth TI'): - ( - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - logging_dir_input, - ) = ti_tab() - with gr.Tab('Utilities'): - utilities_tab( - train_data_dir_input=train_data_dir_input, - reg_data_dir_input=reg_data_dir_input, - output_dir_input=output_dir_input, - logging_dir_input=logging_dir_input, - enable_copy_info_button=True, - ) - - # Show the interface - if not username == '': - interface.launch(auth=(username, password)) - else: - interface.launch() - - def ti_tab( train_data_dir=gr.Textbox(), reg_data_dir=gr.Textbox(), @@ -823,6 +789,45 @@ def ti_tab( ) +def UI(**kwargs): + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css) + + with interface: + with gr.Tab('Dreambooth TI'): + ( + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + logging_dir_input, + ) = ti_tab() + with gr.Tab('Utilities'): + utilities_tab( + train_data_dir_input=train_data_dir_input, + reg_data_dir_input=reg_data_dir_input, + output_dir_input=output_dir_input, + logging_dir_input=logging_dir_input, + enable_copy_info_button=True, + ) + + # Show the interface + launch_kwargs={} + if not kwargs.get('username', None) == '': + launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None)) + if kwargs.get('server_port', 0) > 0: + launch_kwargs["server_port"] = kwargs.get('server_port', 0) + if kwargs.get('inbrowser', False): + launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False) + print(launch_kwargs) + interface.launch(**launch_kwargs) + + if __name__ == '__main__': # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() @@ -832,7 +837,11 @@ if __name__ == '__main__': parser.add_argument( '--password', type=str, default='', help='Password for authentication' ) + parser.add_argument( + '--server_port', type=int, default=0, help='Port to run the server listener on' + ) + parser.add_argument("--inbrowser", action="store_true", help="Open in browser") args = parser.parse_args() - UI(username=args.username, password=args.password) + UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)