Merge pull request #273 from bmaltais/dev

Upgrade Gradio
This commit is contained in:
bmaltais 2023-03-02 20:39:33 -05:00 committed by GitHub
commit 76acf2a200
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 39 deletions

View File

@ -355,6 +355,7 @@ def gradio_source_model():
pretrained_model_name_or_path = gr.Textbox( pretrained_model_name_or_path = gr.Textbox(
label='Pretrained model name or path', label='Pretrained model name or path',
placeholder='enter the path to custom model or name of pretrained model', placeholder='enter the path to custom model or name of pretrained model',
value='runwayml/stable-diffusion-v1-5'
) )
pretrained_model_name_or_path_file = gr.Button( pretrained_model_name_or_path_file = gr.Button(
document_symbol, elem_id='open_folder_small' document_symbol, elem_id='open_folder_small'
@ -373,7 +374,7 @@ def gradio_source_model():
outputs=pretrained_model_name_or_path, outputs=pretrained_model_name_or_path,
) )
model_list = gr.Dropdown( model_list = gr.Dropdown(
label='(Optional) Model Quick Pick', label='Model Quick Pick',
choices=[ choices=[
'custom', 'custom',
'stabilityai/stable-diffusion-2-1-base', 'stabilityai/stable-diffusion-2-1-base',
@ -383,6 +384,7 @@ def gradio_source_model():
'runwayml/stable-diffusion-v1-5', 'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4', 'CompVis/stable-diffusion-v1-4',
], ],
value='runwayml/stable-diffusion-v1-5'
) )
save_model_as = gr.Dropdown( save_model_as = gr.Dropdown(
label='Save trained model as', label='Save trained model as',
@ -397,7 +399,7 @@ def gradio_source_model():
) )
with gr.Row(): with gr.Row():
v2 = gr.Checkbox(label='v2', value=True) v2 = gr.Checkbox(label='v2', value=False)
v_parameterization = gr.Checkbox( v_parameterization = gr.Checkbox(
label='v_parameterization', value=False label='v_parameterization', value=False
) )

View File

@ -123,7 +123,7 @@ def save_configuration(
caption_dropout_rate, caption_dropout_rate,
optimizer, optimizer,
optimizer_args,noise_offset, optimizer_args,noise_offset,
locon=0, conv_dim=0, conv_alpha=0, LoRA_type='Standard', conv_dim=0, conv_alpha=0,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -231,7 +231,7 @@ def open_configuration(
caption_dropout_rate, caption_dropout_rate,
optimizer, optimizer,
optimizer_args,noise_offset, optimizer_args,noise_offset,
locon=0, conv_dim=0, conv_alpha=0, LoRA_type='Standard', conv_dim=0, conv_alpha=0,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -256,6 +256,12 @@ def open_configuration(
if not key in ['file_path']: if not key in ['file_path']:
values.append(my_data.get(key, value)) values.append(my_data.get(key, value))
# This next section is about making the LoCon parameters visible if LoRA_type = 'Standard'
if my_data.get('LoRA_type', 'Standard') == 'LoCon':
values.append(gr.Group.update(visible=True))
else:
values.append(gr.Group.update(visible=False))
return tuple(values) return tuple(values)
@ -319,7 +325,7 @@ def train_model(
caption_dropout_rate, caption_dropout_rate,
optimizer, optimizer,
optimizer_args,noise_offset, optimizer_args,noise_offset,
locon, conv_dim, conv_alpha, LoRA_type, conv_dim, conv_alpha,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -455,7 +461,7 @@ def train_model(
run_cmd += f' --save_model_as={save_model_as}' run_cmd += f' --save_model_as={save_model_as}'
if not float(prior_loss_weight) == 1.0: if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}' run_cmd += f' --prior_loss_weight={prior_loss_weight}'
if locon: if LoRA_type == 'LoCon':
getlocon(os.path.exists(os.path.join(path_of_this_folder, 'locon'))) getlocon(os.path.exists(os.path.join(path_of_this_folder, 'locon')))
run_cmd += f' --network_module=locon.locon.locon_kohya' run_cmd += f' --network_module=locon.locon.locon_kohya'
run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"' run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
@ -634,6 +640,14 @@ def lora_tab(
) )
with gr.Tab('Training parameters'): with gr.Tab('Training parameters'):
with gr.Row(): with gr.Row():
LoRA_type = gr.Dropdown(
label='LoRA type',
choices=[
'Standard',
'LoCon',
],
value='Standard'
)
lora_network_weights = gr.Textbox( lora_network_weights = gr.Textbox(
label='LoRA network weights', label='LoRA network weights',
placeholder='{Optional) Path to existing LoRA network weights to resume training', placeholder='{Optional) Path to existing LoRA network weights to resume training',
@ -666,6 +680,7 @@ def lora_tab(
lr_scheduler_value='cosine', lr_scheduler_value='cosine',
lr_warmup_value='10', lr_warmup_value='10',
) )
with gr.Row(): with gr.Row():
text_encoder_lr = gr.Textbox( text_encoder_lr = gr.Textbox(
label='Text Encoder learning rate', label='Text Encoder learning rate',
@ -693,6 +708,33 @@ def lora_tab(
step=1, step=1,
interactive=True, interactive=True,
) )
with gr.Group(visible=False) as LoCon_group:
def LoRA_type_change(LoRA_type):
if LoRA_type == "LoCon":
return gr.Group.update(visible=True)
else:
return gr.Group.update(visible=False)
with gr.Row():
# locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False)
conv_dim = gr.Slider(
minimum=1,
maximum=512,
value=1,
step=1,
label='LoCon Convolution Rank (Dimension)',
)
conv_alpha = gr.Slider(
minimum=1,
maximum=512,
value=1,
step=1,
label='LoCon Convolution Alpha',
)
# Show of hide LoCon conv settings depending on LoRA type selection
LoRA_type.change(LoRA_type_change, inputs=[LoRA_type], outputs=[LoCon_group])
with gr.Row(): with gr.Row():
max_resolution = gr.Textbox( max_resolution = gr.Textbox(
label='Max resolution', label='Max resolution',
@ -708,22 +750,6 @@ def lora_tab(
) )
enable_bucket = gr.Checkbox(label='Enable buckets', value=True) enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
with gr.Accordion('Advanced Configuration', open=False): with gr.Accordion('Advanced Configuration', open=False):
with gr.Row():
locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False)
conv_dim = gr.Slider(
minimum=1,
maximum=512,
value=1,
step=1,
label='LoCon Convolution Rank (Dimension)',
)
conv_alpha = gr.Slider(
minimum=1,
maximum=512,
value=1,
step=1,
label='LoCon Convolution Alpha',
)
with gr.Row(): with gr.Row():
no_token_padding = gr.Checkbox( no_token_padding = gr.Checkbox(
label='No token padding', value=False label='No token padding', value=False
@ -869,13 +895,13 @@ def lora_tab(
caption_dropout_rate, caption_dropout_rate,
optimizer, optimizer,
optimizer_args,noise_offset, optimizer_args,noise_offset,
locon, conv_dim, conv_alpha, LoRA_type, conv_dim, conv_alpha,
] ]
button_open_config.click( button_open_config.click(
open_configuration, open_configuration,
inputs=[config_file_name] + settings_list, inputs=[config_file_name] + settings_list,
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list + [LoCon_group],
) )
button_save_config.click( button_save_config.click(

View File

@ -1,28 +1,28 @@
accelerate==0.15.0 accelerate==0.15.0
transformers==4.26.0
ftfy==6.1.1
albumentations==1.3.0 albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
diffusers[torch]==0.10.2
pytorch-lightning==1.9.0
bitsandbytes==0.35.0
tensorboard==2.10.1
safetensors==0.2.6
gradio==3.16.2
altair==4.2.2 altair==4.2.2
easygui==0.98.3 bitsandbytes==0.35.0
tk==0.1.0
lion-pytorch==0.0.6
dadaptation==1.5 dadaptation==1.5
diffusers[torch]==0.10.2
easygui==0.98.3
einops==0.6.0
ftfy==6.1.1
gradio==3.19.1
lion-pytorch==0.0.6
opencv-python==4.7.0.68
pytorch-lightning==1.9.0
safetensors==0.2.6
tensorboard==2.10.1
tk==0.1.0
transformers==4.26.0
# for BLIP captioning # for BLIP captioning
fairscale==0.4.13
requests==2.28.2 requests==2.28.2
timm==0.6.12 timm==0.6.12
fairscale==0.4.13
# for WD14 captioning # for WD14 captioning
# tensorflow<2.11 # tensorflow<2.11
tensorflow==2.10.1
huggingface-hub==0.12.0 huggingface-hub==0.12.0
tensorflow==2.10.1
# xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl # xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
# for kohya_ss library # for kohya_ss library
. .