diff --git a/lora_gui.py b/lora_gui.py
index 4ca9d1c..ae703a2 100644
--- a/lora_gui.py
+++ b/lora_gui.py
@@ -46,6 +46,19 @@ folder_symbol = '\U0001f4c2'  # 📂
 refresh_symbol = '\U0001f504'  # 🔄
 save_style_symbol = '\U0001f4be'  # 💾
 document_symbol = '\U0001F4C4'  # 📄
+path_of_this_folder = os.getcwd()
+
+def getlocon(existence):
+    now_path = os.getcwd()
+    if existence:
+        print('Checking LoCon script version...')
+        os.chdir(os.path.join(path_of_this_folder, 'locon'))
+        os.system('git pull')
+        os.chdir(now_path)
+    else:
+        os.chdir(path_of_this_folder)
+        os.system('git clone https://github.com/KohakuBlueleaf/LoCon.git locon')
+        os.chdir(now_path)
 
 
 def save_configuration(
@@ -109,8 +122,8 @@ def save_configuration(
     caption_dropout_every_n_epochs,
     caption_dropout_rate,
     optimizer,
-    optimizer_args,
-    noise_offset,
+    optimizer_args, noise_offset,
+    locon=0, conv_dim=0, conv_alpha=0,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -217,8 +230,8 @@ def open_configuration(
     caption_dropout_every_n_epochs,
     caption_dropout_rate,
     optimizer,
-    optimizer_args,
-    noise_offset,
+    optimizer_args, noise_offset,
+    locon=0, conv_dim=0, conv_alpha=0,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -305,9 +318,9 @@ def train_model(
     caption_dropout_every_n_epochs,
     caption_dropout_rate,
     optimizer,
-    optimizer_args,
-    noise_offset,
-):
+    optimizer_args, noise_offset,
+    locon, conv_dim, conv_alpha,
+):
     if pretrained_model_name_or_path == '':
         msgbox('Source model information is missing')
         return
@@ -442,7 +455,12 @@
     run_cmd += f' --save_model_as={save_model_as}'
     if not float(prior_loss_weight) == 1.0:
         run_cmd += f' --prior_loss_weight={prior_loss_weight}'
-    run_cmd += f' --network_module=networks.lora'
+    if locon:
+        getlocon(os.path.exists(os.path.join(path_of_this_folder, 'locon')))
+        run_cmd += f' --network_module=locon.locon.locon_kohya'
+        run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
+    else:
+        run_cmd += f' --network_module=networks.lora'
 
     if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
         if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
@@ -660,19 +678,19 @@
                 placeholder='Optional',
             )
             network_dim = gr.Slider(
-                minimum=4,
+                minimum=1,
                 maximum=1024,
                 label='Network Rank (Dimension)',
                 value=8,
-                step=4,
+                step=1,
                 interactive=True,
             )
             network_alpha = gr.Slider(
-                minimum=4,
+                minimum=1,
                 maximum=1024,
                 label='Network Alpha',
                 value=1,
-                step=4,
+                step=1,
                 interactive=True,
             )
         with gr.Row():
@@ -690,6 +708,22 @@
             )
             enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
         with gr.Accordion('Advanced Configuration', open=False):
+            with gr.Row():
+                locon = gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models; may not be compatible with some utilities yet)', value=False)
+                conv_dim = gr.Slider(
+                    minimum=1,
+                    maximum=512,
+                    value=1,
+                    step=1,
+                    label='LoCon Convolution Rank (Dimension)',
+                )
+                conv_alpha = gr.Slider(
+                    minimum=1,
+                    maximum=512,
+                    value=1,
+                    step=1,
+                    label='LoCon Convolution Alpha',
+                )
             with gr.Row():
                 no_token_padding = gr.Checkbox(
                     label='No token padding', value=False
@@ -834,8 +868,8 @@
         caption_dropout_every_n_epochs,
         caption_dropout_rate,
         optimizer,
-        optimizer_args,
-        noise_offset,
+        optimizer_args, noise_offset,
+        locon, conv_dim, conv_alpha,
     ]
 
     button_open_config.click(
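
Note on getlocon(): shelling out via os.system() after os.chdir() changes the
GUI process's working directory while git runs and silently ignores git
failures. A minimal alternative sketch using subprocess.run() with cwd=,
assuming the same repo URL and 'locon' folder layout as the patch above; the
type hint, docstring, and check=True are editorial additions, not part of the
patch:

    import os
    import subprocess

    path_of_this_folder = os.getcwd()

    def getlocon(existence: bool) -> None:
        """Clone the LoCon repo on first use; otherwise pull the latest version."""
        locon_dir = os.path.join(path_of_this_folder, 'locon')
        if existence:
            print('Checking LoCon script version...')
            # cwd= scopes the git call to the repo; no os.chdir() needed.
            subprocess.run(['git', 'pull'], cwd=locon_dir, check=True)
        else:
            # First run: clone into <path_of_this_folder>/locon.
            subprocess.run(
                ['git', 'clone', 'https://github.com/KohakuBlueleaf/LoCon.git', 'locon'],
                cwd=path_of_this_folder,
                check=True,
            )

With the LoCon checkbox enabled, train_model() appends
--network_module=locon.locon.locon_kohya --network_args "conv_dim=..."
"conv_alpha=..." to the assembled training command instead of
--network_module=networks.lora.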