Merge branch 'LoCon' into dev

bmaltais 2023-03-02 07:53:01 -05:00 committed by GitHub
commit 11b3955032

@@ -46,6 +46,19 @@ folder_symbol = '\U0001f4c2' # 📂
 refresh_symbol = '\U0001f504' # 🔄
 save_style_symbol = '\U0001f4be' # 💾
 document_symbol = '\U0001F4C4' # 📄
+path_of_this_folder = os.getcwd()
+
+def getlocon(existence):
+    now_path = os.getcwd()
+    if existence:
+        print('Checking LoCon script version...')
+        os.chdir(os.path.join(path_of_this_folder, 'locon'))
+        os.system('git pull')
+        os.chdir(now_path)
+    else:
+        os.chdir(path_of_this_folder)
+        os.system('git clone https://github.com/KohakuBlueleaf/LoCon.git locon')
+        os.chdir(now_path)
 def save_configuration(
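Note on the hunk above: getlocon() either runs `git pull` in an existing locon/ checkout or clones the LoCon repository on first use. A minimal standalone sketch of the same clone-or-pull pattern, using subprocess with an explicit working directory instead of os.system plus os.chdir (the helper name ensure_locon is illustrative, not part of the commit):

    import os
    import subprocess

    LOCON_REPO = 'https://github.com/KohakuBlueleaf/LoCon.git'

    def ensure_locon(base_dir: str) -> None:
        locon_dir = os.path.join(base_dir, 'locon')
        if os.path.exists(locon_dir):
            print('Checking LoCon script version...')
            # Update the existing checkout in place.
            subprocess.run(['git', 'pull'], cwd=locon_dir, check=True)
        else:
            # First run: clone LoCon next to the GUI scripts.
            subprocess.run(['git', 'clone', LOCON_REPO, 'locon'], cwd=base_dir, check=True)

Passing cwd= avoids mutating the process-wide working directory, and check=True surfaces git failures as exceptions instead of continuing silently.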
@@ -109,8 +122,8 @@ def save_configuration(
     caption_dropout_every_n_epochs,
     caption_dropout_rate,
     optimizer,
-    optimizer_args,
-    noise_offset,
+    optimizer_args, noise_offset,
+    locon=0, conv_dim=0, conv_alpha=0,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
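The new parameters take =0 defaults, so any caller that does not pass them keeps working; save_configuration() then captures every argument by name via locals(). A toy reduction of that capture pattern (not the real function, which takes many more parameters):

    def save_configuration(optimizer_args, noise_offset, locon=0, conv_dim=0, conv_alpha=0):
        # locals() at function entry maps each parameter name to its
        # value, so newly added fields are picked up automatically.
        parameters = list(locals().items())
        return dict(parameters)

    print(save_configuration('', 0.0))
    # {'optimizer_args': '', 'noise_offset': 0.0, 'locon': 0, 'conv_dim': 0, 'conv_alpha': 0}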
@@ -217,8 +230,8 @@ def open_configuration(
     caption_dropout_every_n_epochs,
     caption_dropout_rate,
     optimizer,
-    optimizer_args,
-    noise_offset,
+    optimizer_args, noise_offset,
+    locon=0, conv_dim=0, conv_alpha=0,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -305,9 +318,9 @@ def train_model(
     caption_dropout_every_n_epochs,
     caption_dropout_rate,
     optimizer,
-    optimizer_args,
-    noise_offset,
+    optimizer_args, noise_offset,
+    locon, conv_dim, conv_alpha,
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
         return
@@ -442,7 +455,12 @@ def train_model(
     run_cmd += f' --save_model_as={save_model_as}'
     if not float(prior_loss_weight) == 1.0:
         run_cmd += f' --prior_loss_weight={prior_loss_weight}'
-    run_cmd += f' --network_module=networks.lora'
+    if locon:
+        getlocon(os.path.exists(os.path.join(path_of_this_folder, 'locon')))
+        run_cmd += f' --network_module=locon.locon.locon_kohya'
+        run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
+    else:
+        run_cmd += f' --network_module=networks.lora'
 
     if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
         if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
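For a quick sense of what the branch above emits: with hypothetical values conv_dim=8 and conv_alpha=1, the LoCon path appends the alternative network module plus its extra network arguments to the accumulated command line (getlocon() is omitted here; this sketch only shows the string assembly):

    locon, conv_dim, conv_alpha = True, 8, 1  # hypothetical values
    run_cmd = ''
    if locon:
        run_cmd += ' --network_module=locon.locon.locon_kohya'
        run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
    else:
        run_cmd += ' --network_module=networks.lora'
    print(run_cmd)
    #  --network_module=locon.locon.locon_kohya --network_args "conv_dim=8" "conv_alpha=1"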
@@ -660,19 +678,19 @@ def lora_tab(
                 placeholder='Optional',
             )
             network_dim = gr.Slider(
-                minimum=4,
+                minimum=1,
                 maximum=1024,
                 label='Network Rank (Dimension)',
                 value=8,
-                step=4,
+                step=1,
                 interactive=True,
             )
             network_alpha = gr.Slider(
-                minimum=4,
+                minimum=1,
                 maximum=1024,
                 label='Network Alpha',
                 value=1,
-                step=4,
+                step=1,
                 interactive=True,
             )
         with gr.Row():
@@ -690,6 +708,22 @@ def lora_tab(
             )
             enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
         with gr.Accordion('Advanced Configuration', open=False):
+            with gr.Row():
+                locon = gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models; may not work with some utilities yet)', value=False)
+                conv_dim = gr.Slider(
+                    minimum=1,
+                    maximum=512,
+                    value=1,
+                    step=1,
+                    label='LoCon Convolution Rank (Dimension)',
+                )
+                conv_alpha = gr.Slider(
+                    minimum=1,
+                    maximum=512,
+                    value=1,
+                    step=1,
+                    label='LoCon Convolution Alpha',
+                )
             with gr.Row():
                 no_token_padding = gr.Checkbox(
                     label='No token padding', value=False
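The sixteen added lines render as one row of three controls inside the Advanced Configuration accordion. A minimal standalone sketch showing just these controls outside the full GUI (the Blocks wrapper and the shortened checkbox label are illustrative):

    import gradio as gr

    with gr.Blocks() as demo:
        with gr.Row():
            locon = gr.Checkbox(label='Train a LoCon instead of a general LoRA', value=False)
            conv_dim = gr.Slider(minimum=1, maximum=512, value=1, step=1,
                                 label='LoCon Convolution Rank (Dimension)')
            conv_alpha = gr.Slider(minimum=1, maximum=512, value=1, step=1,
                                   label='LoCon Convolution Alpha')

    if __name__ == '__main__':
        demo.launch()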
@@ -834,8 +868,8 @@ def lora_tab(
             caption_dropout_every_n_epochs,
             caption_dropout_rate,
             optimizer,
-            optimizer_args,
-            noise_offset,
+            optimizer_args, noise_offset,
+            locon, conv_dim, conv_alpha,
         ]
         button_open_config.click(