diff --git a/lora_gui.py b/lora_gui.py
index 8b15a94..29523a7 100644
--- a/lora_gui.py
+++ b/lora_gui.py
@@ -10,6 +10,7 @@ import os
 import subprocess
 import pathlib
 import argparse
+import shutil
 from library.common_gui import (
     get_folder_path,
     remove_doublequote,
@@ -40,7 +41,12 @@ folder_symbol = '\U0001f4c2'  # 📂
 refresh_symbol = '\U0001f504'  # 🔄
 save_style_symbol = '\U0001f4be'  # 💾
 document_symbol = '\U0001F4C4'  # 📄
+locon_path = os.getcwd()+'\\locon\\'
+def getlocon():
+    os.system('git clone https://github.com/KohakuBlueleaf/LoCon.git')
+    os.system('ren '+locon_path[:-6]+'\\LoCon\\'+' locon_github-sourcecode')
+    shutil.copytree(locon_path[:-6]+'locon_github-sourcecode\\locon\\', locon_path)


 def save_configuration(
     save_as,
@@ -102,7 +108,7 @@ def save_configuration(
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
     optimizer,
-    optimizer_args,noise_offset,
+    optimizer_args,noise_offset, locon = 0
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -208,7 +214,7 @@ def open_configuration(
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
     optimizer,
-    optimizer_args,noise_offset,
+    optimizer_args,noise_offset, locon=0
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -292,7 +298,7 @@ def train_model(
     bucket_reso_steps,
     caption_dropout_every_n_epochs, caption_dropout_rate,
     optimizer,
-    optimizer_args,noise_offset,
+    optimizer_args,noise_offset, locon
 ):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
@@ -428,7 +434,12 @@ def train_model(
     run_cmd += f' --save_model_as={save_model_as}'
     if not float(prior_loss_weight) == 1.0:
         run_cmd += f' --prior_loss_weight={prior_loss_weight}'
-    run_cmd += f' --network_module=networks.lora'
+    if locon:
+        if not os.path.exists(locon_path):
+            getlocon()
+        run_cmd += ' --network_module=locon.locon_kohya'
+    else:
+        run_cmd += f' --network_module=networks.lora'

     if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
         if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
@@ -676,6 +687,8 @@ def lora_tab(
             )
             enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
         with gr.Accordion('Advanced Configuration', open=False):
+            with gr.Row():
+                locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA', value=False)
             with gr.Row():
                 no_token_padding = gr.Checkbox(
                     label='No token padding', value=False
@@ -805,7 +818,7 @@
         bucket_reso_steps,
         caption_dropout_every_n_epochs, caption_dropout_rate,
         optimizer,
-        optimizer_args,noise_offset,
+        optimizer_args,noise_offset,locon
     ]

     button_open_config.click(
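
Note on the new `getlocon()` helper: as written it is Windows-only, since it shells out to `ren` and builds paths with hard-coded backslashes (the `locon_path[:-6]` slices strip the trailing `locon\` to recover the working directory). A cross-platform sketch of the same three steps, clone the LoCon repo, rename the checkout, and copy its `locon` package next to the GUI, could look like the following; the `ROOT`, `LOCON_PKG`, and `CHECKOUT` names are illustrative, not from the patch:

```python
import shutil
import subprocess
from pathlib import Path

ROOT = Path.cwd()                            # directory the GUI runs from
LOCON_PKG = ROOT / 'locon'                   # what the patch calls locon_path
CHECKOUT = ROOT / 'locon_github-sourcecode'  # renamed clone, as in the patch

def getlocon():
    # 1. Clone KohakuBlueleaf/LoCon next to the GUI.
    subprocess.run(
        ['git', 'clone', 'https://github.com/KohakuBlueleaf/LoCon.git'],
        cwd=ROOT, check=True,
    )
    # 2. Rename the checkout (portable replacement for the Windows 'ren' call).
    (ROOT / 'LoCon').rename(CHECKOUT)
    # 3. Copy just the 'locon' package into the app root so that
    #    --network_module=locon.locon_kohya resolves at import time.
    shutil.copytree(CHECKOUT / 'locon', LOCON_PKG)
```

As in the patch, the call site should keep the `if not os.path.exists(locon_path)` guard: `shutil.copytree` raises `FileExistsError` when the destination directory already exists, so the clone-and-copy must run only once.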