From b1fb87a9e17dfb760b99dfd0143bb25097d54b7c Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 28 Feb 2023 07:45:42 -0500 Subject: [PATCH] Merging PR into LoCon branch --- lora_gui.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/lora_gui.py b/lora_gui.py index eb0b567..e79d067 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -10,6 +10,7 @@ import os import subprocess import pathlib import argparse +import shutil from library.common_gui import ( get_folder_path, remove_doublequote, @@ -43,6 +44,12 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 +locon_path = os.getcwd()+'\\locon\\' + +def getlocon(): + os.system('git clone https://github.com/KohakuBlueleaf/LoCon.git') + os.system('ren '+locon_path[:-6]+'\\LoCon\\'+' locon_github-sourcecode') + shutil.copytree(locon_path[:-6]+'locon_github-sourcecode\\locon\\', locon_path) def save_configuration( save_as, @@ -104,7 +111,7 @@ def save_configuration( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, + optimizer_args,noise_offset, locon = 0 ): # Get list of function parameters and values parameters = list(locals().items()) @@ -210,7 +217,7 @@ def open_configuration( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, + optimizer_args,noise_offset, locon=0 ): # Get list of function parameters and values parameters = list(locals().items()) @@ -294,7 +301,7 @@ def train_model( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, + optimizer_args,noise_offset, locon ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -430,7 +437,12 @@ def train_model( run_cmd += f' --save_model_as={save_model_as}' if not float(prior_loss_weight) == 1.0: run_cmd += f' --prior_loss_weight={prior_loss_weight}' - run_cmd += f' --network_module=networks.lora' + if locon: + if not os.path.exists(locon_path): + getlocon() + run_cmd += ' --network_module=locon.locon_kohya' + else: + run_cmd += f' --network_module=networks.lora' if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0): if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0): @@ -678,6 +690,8 @@ def lora_tab( ) enable_bucket = gr.Checkbox(label='Enable buckets', value=True) with gr.Accordion('Advanced Configuration', open=False): + with gr.Row(): + locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (may not be able to merge now)', value=False) with gr.Row(): no_token_padding = gr.Checkbox( label='No token padding', value=False @@ -819,7 +833,7 @@ def lora_tab( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, optimizer, - optimizer_args,noise_offset, + optimizer_args,noise_offset,locon ] button_open_config.click(