Merge branch 'LoCon' into dev
This commit is contained in:
commit
11b3955032
58
lora_gui.py
58
lora_gui.py
@ -46,6 +46,19 @@ folder_symbol = '\U0001f4c2' # 📂
|
|||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
path_of_this_folder = os.getcwd()
|
||||||
|
|
||||||
|
def getlocon(existance):
|
||||||
|
now_path = os.getcwd()
|
||||||
|
if existance:
|
||||||
|
print('Checking LoCon script version...')
|
||||||
|
os.chdir(os.path.join(path_of_this_folder, 'locon'))
|
||||||
|
os.system('git pull')
|
||||||
|
os.chdir(now_path)
|
||||||
|
else:
|
||||||
|
os.chdir(path_of_this_folder)
|
||||||
|
os.system('git clone https://github.com/KohakuBlueleaf/LoCon.git locon')
|
||||||
|
os.chdir(now_path)
|
||||||
|
|
||||||
|
|
||||||
def save_configuration(
|
def save_configuration(
|
||||||
@ -109,8 +122,8 @@ def save_configuration(
|
|||||||
caption_dropout_every_n_epochs,
|
caption_dropout_every_n_epochs,
|
||||||
caption_dropout_rate,
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,
|
optimizer_args,noise_offset,
|
||||||
noise_offset,
|
locon=0, conv_dim=0, conv_alpha=0,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -217,8 +230,8 @@ def open_configuration(
|
|||||||
caption_dropout_every_n_epochs,
|
caption_dropout_every_n_epochs,
|
||||||
caption_dropout_rate,
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,
|
optimizer_args,noise_offset,
|
||||||
noise_offset,
|
locon=0, conv_dim=0, conv_alpha=0,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -305,8 +318,8 @@ def train_model(
|
|||||||
caption_dropout_every_n_epochs,
|
caption_dropout_every_n_epochs,
|
||||||
caption_dropout_rate,
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,
|
optimizer_args,noise_offset,
|
||||||
noise_offset,
|
locon, conv_dim, conv_alpha,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -442,6 +455,11 @@ def train_model(
|
|||||||
run_cmd += f' --save_model_as={save_model_as}'
|
run_cmd += f' --save_model_as={save_model_as}'
|
||||||
if not float(prior_loss_weight) == 1.0:
|
if not float(prior_loss_weight) == 1.0:
|
||||||
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
||||||
|
if locon:
|
||||||
|
getlocon(os.path.exists(os.path.join(path_of_this_folder, 'locon')))
|
||||||
|
run_cmd += f' --network_module=locon.locon.locon_kohya'
|
||||||
|
run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
|
||||||
|
else:
|
||||||
run_cmd += f' --network_module=networks.lora'
|
run_cmd += f' --network_module=networks.lora'
|
||||||
|
|
||||||
if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
|
if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
|
||||||
@ -660,19 +678,19 @@ def lora_tab(
|
|||||||
placeholder='Optional',
|
placeholder='Optional',
|
||||||
)
|
)
|
||||||
network_dim = gr.Slider(
|
network_dim = gr.Slider(
|
||||||
minimum=4,
|
minimum=1,
|
||||||
maximum=1024,
|
maximum=1024,
|
||||||
label='Network Rank (Dimension)',
|
label='Network Rank (Dimension)',
|
||||||
value=8,
|
value=8,
|
||||||
step=4,
|
step=1,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
network_alpha = gr.Slider(
|
network_alpha = gr.Slider(
|
||||||
minimum=4,
|
minimum=1,
|
||||||
maximum=1024,
|
maximum=1024,
|
||||||
label='Network Alpha',
|
label='Network Alpha',
|
||||||
value=1,
|
value=1,
|
||||||
step=4,
|
step=1,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -690,6 +708,22 @@ def lora_tab(
|
|||||||
)
|
)
|
||||||
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
||||||
with gr.Accordion('Advanced Configuration', open=False):
|
with gr.Accordion('Advanced Configuration', open=False):
|
||||||
|
with gr.Row():
|
||||||
|
locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False)
|
||||||
|
conv_dim = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=512,
|
||||||
|
value=1,
|
||||||
|
step=1,
|
||||||
|
label='LoCon Convolution Rank (Dimension)',
|
||||||
|
)
|
||||||
|
conv_alpha = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=512,
|
||||||
|
value=1,
|
||||||
|
step=1,
|
||||||
|
label='LoCon Convolution Alpha',
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
no_token_padding = gr.Checkbox(
|
no_token_padding = gr.Checkbox(
|
||||||
label='No token padding', value=False
|
label='No token padding', value=False
|
||||||
@ -834,8 +868,8 @@ def lora_tab(
|
|||||||
caption_dropout_every_n_epochs,
|
caption_dropout_every_n_epochs,
|
||||||
caption_dropout_rate,
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,
|
optimizer_args,noise_offset,
|
||||||
noise_offset,
|
locon, conv_dim, conv_alpha,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
Loading…
Reference in New Issue
Block a user