add new LoCon args

This commit is contained in:
Ki-wimon 2023-03-01 12:19:18 +08:00 committed by GitHub
parent d76fe7d4e0
commit c07e3bba76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -445,6 +445,7 @@ def train_model(
     if locon:
         getlocon(os.path.exists(os.path.join(path_of_this_folder, 'locon')))
         run_cmd += f' --network_module=locon.locon.locon_kohya'
+        run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
     else:
         run_cmd += f' --network_module=networks.lora'
@@ -664,19 +665,19 @@ def lora_tab(
         placeholder='Optional',
     )
     network_dim = gr.Slider(
-        minimum=4,
+        minimum=1,
         maximum=1024,
         label='Network Rank (Dimension)',
         value=8,
-        step=4,
+        step=1,
         interactive=True,
     )
     network_alpha = gr.Slider(
-        minimum=4,
+        minimum=1,
         maximum=1024,
         label='Network Alpha',
         value=1,
-        step=4,
+        step=1,
         interactive=True,
     )
     with gr.Row():
@@ -696,6 +697,20 @@ def lora_tab(
     with gr.Accordion('Advanced Configuration', open=False):
         with gr.Row():
             locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (may not be able to merge now)', value=False)
+            conv_dim = gr.Slider(
+                minimum=1,
+                maximum=512,
+                value=0,
+                step=1,
+                label='LoCon Convolution Rank (Dimension)',
+            )
+            conv_alpha = gr.Slider(
+                minimum=1,
+                maximum=512,
+                value=0,
+                step=1,
+                label='LoCon Convolution Alpha',
+            )
         with gr.Row():
             no_token_padding = gr.Checkbox(
                 label='No token padding', value=False