Add server_port and inbrowser support

- to all gui scripts
This commit is contained in:
bmaltais 2023-02-10 08:22:03 -05:00
parent 56d171c55b
commit e5f8ba559f
9 changed files with 202 additions and 156 deletions

View File

@ -435,40 +435,6 @@ def train_model(
save_inference_file(output_dir, v2, v_parameterization, output_name) save_inference_file(output_dir, v2, v_parameterization, output_name)
def UI(username, password):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Dreambooth'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = dreambooth_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
def dreambooth_tab( def dreambooth_tab(
train_data_dir=gr.Textbox(), train_data_dir=gr.Textbox(),
reg_data_dir=gr.Textbox(), reg_data_dir=gr.Textbox(),
@ -735,6 +701,44 @@ def dreambooth_tab(
) )
def UI(**kwargs):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Dreambooth'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = dreambooth_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__': if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48) # torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -744,7 +748,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--password', type=str, default='', help='Password for authentication' '--password', type=str, default='', help='Password for authentication'
) )
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args() args = parser.parse_args()
UI(username=args.username, password=args.password) UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)

View File

@ -431,30 +431,6 @@ def remove_doublequote(file_path):
return file_path return file_path
def UI(username, password):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Finetune'):
finetune_tab()
with gr.Tab('Utilities'):
utilities_tab(enable_dreambooth_tab=False)
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
def finetune_tab(): def finetune_tab():
dummy_ft_true = gr.Label(value=True, visible=False) dummy_ft_true = gr.Label(value=True, visible=False)
dummy_ft_false = gr.Label(value=False, visible=False) dummy_ft_false = gr.Label(value=False, visible=False)
@ -708,6 +684,35 @@ def finetune_tab():
) )
def UI(**kwargs):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Finetune'):
finetune_tab()
with gr.Tab('Utilities'):
utilities_tab(enable_dreambooth_tab=False)
# Show the interface
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__': if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48) # torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -717,7 +722,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--password', type=str, default='', help='Password for authentication' '--password', type=str, default='', help='Password for authentication'
) )
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args() args = parser.parse_args()
UI(username=args.username, password=args.password) UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)

View File

@ -1,10 +1,6 @@
@echo off @echo off
set VENV_DIR=.\venv call venv\Scripts\activate.bat
set PYTHON=python python.exe kohya_gui.py %*
call %VENV_DIR%\Scripts\activate.bat
%PYTHON% kohya_gui.py
pause pause

View File

@ -1,2 +1,2 @@
.\venv\Scripts\activate .\venv\Scripts\activate
python.exe kohya_gui.py python.exe kohya_gui.py $args

View File

@ -10,8 +10,7 @@ from library.merge_lora_gui import gradio_merge_lora_tab
from lora_gui import lora_tab from lora_gui import lora_tab
def UI(username, password, inbrowser, server_port): def UI(**kwargs):
css = '' css = ''
if os.path.exists('./style.css'): if os.path.exists('./style.css'):
@ -47,13 +46,18 @@ def UI(username, password, inbrowser, server_port):
gradio_merge_lora_tab() gradio_merge_lora_tab()
# Show the interface # Show the interface
kwargs = {} launch_kwargs = {}
if username: username = kwargs.get('username')
kwargs["auth"] = (username, password) password = kwargs.get('password')
server_port = kwargs.get('server_port', 0)
inbrowser = kwargs.get('inbrowser', False)
if username and password:
launch_kwargs["auth"] = (username, password)
if server_port > 0: if server_port > 0:
kwargs["server_port"] = server_port launch_kwargs["server_port"] = server_port
kwargs["inbrowser"] = inbrowser if inbrowser:
interface.launch(**kwargs) launch_kwargs["inbrowser"] = inbrowser
interface.launch(**launch_kwargs)
if __name__ == '__main__': if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48) # torch.cuda.set_per_process_memory_fraction(0.48)

View File

@ -568,9 +568,11 @@ def gradio_advanced_training():
label="Dropout caption every n epochs", label="Dropout caption every n epochs",
value=0 value=0
) )
caption_dropout_rate = gr.Number( caption_dropout_rate = gr.Slider(
label="Rate of caption dropout", label="Rate of caption dropout",
value=0 value=0,
minimum=0,
maximum=1
) )
with gr.Row(): with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False) save_state = gr.Checkbox(label='Save training state', value=False)

View File

@ -36,7 +36,7 @@ def utilities_tab(
) )
def UI(username, password): def UI(**kwargs):
css = '' css = ''
if os.path.exists('./style.css'): if os.path.exists('./style.css'):
@ -50,10 +50,15 @@ def UI(username, password):
utilities_tab() utilities_tab()
# Show the interface # Show the interface
if not username == '': launch_kwargs={}
interface.launch(auth=(username, password)) if not kwargs.get('username', None) == '':
else: launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
interface.launch() if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__': if __name__ == '__main__':
@ -65,7 +70,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--password', type=str, default='', help='Password for authentication' '--password', type=str, default='', help='Password for authentication'
) )
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args() args = parser.parse_args()
UI(username=args.username, password=args.password) UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)

View File

@ -495,40 +495,6 @@ def train_model(
save_inference_file(output_dir, v2, v_parameterization, output_name) save_inference_file(output_dir, v2, v_parameterization, output_name)
def UI(username, password):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('LoRA'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = lora_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
def lora_tab( def lora_tab(
train_data_dir_input=gr.Textbox(), train_data_dir_input=gr.Textbox(),
reg_data_dir_input=gr.Textbox(), reg_data_dir_input=gr.Textbox(),
@ -644,7 +610,7 @@ def lora_tab(
caption_extension, caption_extension,
cache_latents, cache_latents,
) = gradio_training( ) = gradio_training(
learning_rate_value='1e-5', learning_rate_value='0.0001',
lr_scheduler_value='cosine', lr_scheduler_value='cosine',
lr_warmup_value='10', lr_warmup_value='10',
) )
@ -656,7 +622,7 @@ def lora_tab(
) )
unet_lr = gr.Textbox( unet_lr = gr.Textbox(
label='Unet learning rate', label='Unet learning rate',
value='1e-3', value='0.0001',
placeholder='Optional', placeholder='Optional',
) )
network_dim = gr.Slider( network_dim = gr.Slider(
@ -845,6 +811,45 @@ def lora_tab(
) )
def UI(**kwargs):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('LoRA'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = lora_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__': if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48) # torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -854,7 +859,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--password', type=str, default='', help='Password for authentication' '--password', type=str, default='', help='Password for authentication'
) )
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args() args = parser.parse_args()
UI(username=args.username, password=args.password) UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)

View File

@ -481,40 +481,6 @@ def train_model(
save_inference_file(output_dir, v2, v_parameterization, output_name) save_inference_file(output_dir, v2, v_parameterization, output_name)
def UI(username, password):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Dreambooth TI'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = ti_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
def ti_tab( def ti_tab(
train_data_dir=gr.Textbox(), train_data_dir=gr.Textbox(),
reg_data_dir=gr.Textbox(), reg_data_dir=gr.Textbox(),
@ -823,6 +789,45 @@ def ti_tab(
) )
def UI(**kwargs):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Dreambooth TI'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = ti_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__': if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48) # torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -832,7 +837,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--password', type=str, default='', help='Password for authentication' '--password', type=str, default='', help='Password for authentication'
) )
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args() args = parser.parse_args()
UI(username=args.username, password=args.password) UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)