Merge pull request #251 from Ki-wimon/master
LoCon script auto upgrade feature
This commit is contained in:
commit
384009d4eb
60
lora_gui.py
60
lora_gui.py
@ -10,7 +10,6 @@ import os
|
||||
import subprocess
|
||||
import pathlib
|
||||
import argparse
|
||||
import shutil
|
||||
from library.common_gui import (
|
||||
get_folder_path,
|
||||
remove_doublequote,
|
||||
@ -44,12 +43,19 @@ folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
locon_path = os.getcwd()+'\\locon\\'
|
||||
path_of_this_folder = os.getcwd()
|
||||
|
||||
def getlocon():
|
||||
os.system('git clone https://github.com/KohakuBlueleaf/LoCon.git')
|
||||
os.system('ren '+locon_path[:-6]+'\\LoCon\\'+' locon_github-sourcecode')
|
||||
shutil.copytree(locon_path[:-6]+'locon_github-sourcecode\\locon\\', locon_path)
|
||||
def getlocon(existance):
|
||||
now_path = os.getcwd()
|
||||
if existance:
|
||||
print('Checking LoCon script version...')
|
||||
os.chdir(os.path.join(path_of_this_folder, 'locon'))
|
||||
os.system('git pull')
|
||||
os.chdir(now_path)
|
||||
else:
|
||||
os.chdir(path_of_this_folder)
|
||||
os.system('git clone https://github.com/KohakuBlueleaf/LoCon.git locon')
|
||||
os.chdir(now_path)
|
||||
|
||||
def save_configuration(
|
||||
save_as,
|
||||
@ -111,7 +117,8 @@ def save_configuration(
|
||||
bucket_reso_steps,
|
||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||
optimizer,
|
||||
optimizer_args,noise_offset, locon = 0
|
||||
optimizer_args,noise_offset,
|
||||
locon=0, conv_dim=0, conv_alpha=0,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -217,7 +224,8 @@ def open_configuration(
|
||||
bucket_reso_steps,
|
||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||
optimizer,
|
||||
optimizer_args,noise_offset, locon=0
|
||||
optimizer_args,noise_offset,
|
||||
locon=0, conv_dim=0, conv_alpha=0,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -301,7 +309,8 @@ def train_model(
|
||||
bucket_reso_steps,
|
||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||
optimizer,
|
||||
optimizer_args,noise_offset, locon
|
||||
optimizer_args,noise_offset,
|
||||
locon, conv_dim, conv_alpha,
|
||||
):
|
||||
if pretrained_model_name_or_path == '':
|
||||
msgbox('Source model information is missing')
|
||||
@ -438,9 +447,9 @@ def train_model(
|
||||
if not float(prior_loss_weight) == 1.0:
|
||||
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
||||
if locon:
|
||||
if not os.path.exists(locon_path):
|
||||
getlocon()
|
||||
run_cmd += ' --network_module=locon.locon_kohya'
|
||||
getlocon(os.path.exists(os.path.join(path_of_this_folder, 'locon')))
|
||||
run_cmd += f' --network_module=locon.locon.locon_kohya'
|
||||
run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
|
||||
else:
|
||||
run_cmd += f' --network_module=networks.lora'
|
||||
|
||||
@ -660,19 +669,19 @@ def lora_tab(
|
||||
placeholder='Optional',
|
||||
)
|
||||
network_dim = gr.Slider(
|
||||
minimum=4,
|
||||
minimum=1,
|
||||
maximum=1024,
|
||||
label='Network Rank (Dimension)',
|
||||
value=8,
|
||||
step=4,
|
||||
step=1,
|
||||
interactive=True,
|
||||
)
|
||||
network_alpha = gr.Slider(
|
||||
minimum=4,
|
||||
minimum=1,
|
||||
maximum=1024,
|
||||
label='Network Alpha',
|
||||
value=1,
|
||||
step=4,
|
||||
step=1,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
@ -691,7 +700,21 @@ def lora_tab(
|
||||
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
||||
with gr.Accordion('Advanced Configuration', open=False):
|
||||
with gr.Row():
|
||||
locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (may not be able to merge now)', value=False)
|
||||
locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False)
|
||||
conv_dim = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=512,
|
||||
value=1,
|
||||
step=1,
|
||||
label='LoCon Convolution Rank (Dimension)',
|
||||
)
|
||||
conv_alpha = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=512,
|
||||
value=1,
|
||||
step=1,
|
||||
label='LoCon Convolution Alpha',
|
||||
)
|
||||
with gr.Row():
|
||||
no_token_padding = gr.Checkbox(
|
||||
label='No token padding', value=False
|
||||
@ -833,7 +856,8 @@ def lora_tab(
|
||||
bucket_reso_steps,
|
||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||
optimizer,
|
||||
optimizer_args,noise_offset,locon
|
||||
optimizer_args,noise_offset,
|
||||
locon, conv_dim, conv_alpha,
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
|
Loading…
Reference in New Issue
Block a user