commit
6edc53ae3e
21
.github/workflows/typos.yaml
vendored
Normal file
21
.github/workflows/typos.yaml
vendored
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
---
|
||||||
|
# yamllint disable rule:line-length
|
||||||
|
name: Typos
|
||||||
|
|
||||||
|
on: # yamllint disable-line rule:truthy
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- synchronize
|
||||||
|
- reopened
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: typos-action
|
||||||
|
uses: crate-ci/typos@v1.13.10
|
@ -163,6 +163,8 @@ This will store your a backup file with your current locally installed pip packa
|
|||||||
|
|
||||||
## Change History
|
## Change History
|
||||||
|
|
||||||
|
* 2023/03/02 (v21.1.0):
|
||||||
|
- Add LoCon support (https://github.com/KohakuBlueleaf/LoCon.git) to the Dreambooth LoRA tab. This allows creating a new type of LoRA that includes conv layers as part of the LoRA... hence the name LoCon. LoCon will work with the native Auto1111 implementation of LoRA. If you want to use it with the Kohya_ss additionalNetwork you will need to install this other extension... until Kohya_ss supports it natively: https://github.com/KohakuBlueleaf/a1111-sd-webui-locon
|
||||||
* 2023/03/01 (v21.0.1):
|
* 2023/03/01 (v21.0.1):
|
||||||
- Add warning to tensorboard start if the log information is missing
|
- Add warning to tensorboard start if the log information is missing
|
||||||
- Fix issue with 8bitadam on older config file load
|
- Fix issue with 8bitadam on older config file load
|
||||||
|
15
_typos.toml
Normal file
15
_typos.toml
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# Files for typos
|
||||||
|
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
|
||||||
|
|
||||||
|
[default.extend-identifiers]
|
||||||
|
|
||||||
|
[default.extend-words]
|
||||||
|
NIN="NIN"
|
||||||
|
parms="parms"
|
||||||
|
nin="nin"
|
||||||
|
extention="extention" # Intentionally left
|
||||||
|
nd="nd"
|
||||||
|
|
||||||
|
|
||||||
|
[files]
|
||||||
|
extend-exclude = ["_typos.toml"]
|
@ -95,9 +95,11 @@ def save_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,noise_offset,
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -194,9 +196,11 @@ def open_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,noise_offset,
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -272,9 +276,11 @@ def train_model(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,noise_offset,
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -566,7 +572,8 @@ def dreambooth_tab(
|
|||||||
seed,
|
seed,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
optimizer,optimizer_args,
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
) = gradio_training(
|
) = gradio_training(
|
||||||
learning_rate_value='1e-5',
|
learning_rate_value='1e-5',
|
||||||
lr_scheduler_value='cosine',
|
lr_scheduler_value='cosine',
|
||||||
@ -624,7 +631,9 @@ def dreambooth_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,noise_offset,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
|
noise_offset,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -648,15 +657,15 @@ def dreambooth_tab(
|
|||||||
)
|
)
|
||||||
|
|
||||||
button_run = gr.Button('Train model', variant='primary')
|
button_run = gr.Button('Train model', variant='primary')
|
||||||
|
|
||||||
# Setup gradio tensorboard buttons
|
# Setup gradio tensorboard buttons
|
||||||
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
||||||
|
|
||||||
button_start_tensorboard.click(
|
button_start_tensorboard.click(
|
||||||
start_tensorboard,
|
start_tensorboard,
|
||||||
inputs=logging_dir,
|
inputs=logging_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
button_stop_tensorboard.click(
|
button_stop_tensorboard.click(
|
||||||
stop_tensorboard,
|
stop_tensorboard,
|
||||||
)
|
)
|
||||||
@ -710,8 +719,11 @@ def dreambooth_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
optimizer,optimizer_args,noise_offset,
|
caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
@ -773,16 +785,20 @@ def UI(**kwargs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Show the interface
|
# Show the interface
|
||||||
launch_kwargs={}
|
launch_kwargs = {}
|
||||||
if not kwargs.get('username', None) == '':
|
if not kwargs.get('username', None) == '':
|
||||||
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
|
launch_kwargs['auth'] = (
|
||||||
|
kwargs.get('username', None),
|
||||||
|
kwargs.get('password', None),
|
||||||
|
)
|
||||||
if kwargs.get('server_port', 0) > 0:
|
if kwargs.get('server_port', 0) > 0:
|
||||||
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
|
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
||||||
if kwargs.get('inbrowser', False):
|
if kwargs.get('inbrowser', False):
|
||||||
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
|
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
||||||
print(launch_kwargs)
|
print(launch_kwargs)
|
||||||
interface.launch(**launch_kwargs)
|
interface.launch(**launch_kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -793,10 +809,20 @@ if __name__ == '__main__':
|
|||||||
'--password', type=str, default='', help='Password for authentication'
|
'--password', type=str, default='', help='Password for authentication'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
'--server_port',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='Port to run the server listener on',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--inbrowser', action='store_true', help='Open in browser'
|
||||||
)
|
)
|
||||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
UI(
|
||||||
|
username=args.username,
|
||||||
|
password=args.password,
|
||||||
|
inbrowser=args.inbrowser,
|
||||||
|
server_port=args.server_port,
|
||||||
|
)
|
||||||
|
@ -91,8 +91,11 @@ def save_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
optimizer,optimizer_args,noise_offset,
|
caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -195,8 +198,11 @@ def open_config_file(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
optimizer,optimizer_args,noise_offset,
|
caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -278,8 +284,11 @@ def train_model(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
optimizer,optimizer_args,noise_offset,
|
caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
):
|
):
|
||||||
# create caption json file
|
# create caption json file
|
||||||
if generate_caption_database:
|
if generate_caption_database:
|
||||||
@ -585,7 +594,8 @@ def finetune_tab():
|
|||||||
seed,
|
seed,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
optimizer,optimizer_args,
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
) = gradio_training(learning_rate_value='1e-5')
|
) = gradio_training(learning_rate_value='1e-5')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
|
dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
|
||||||
@ -617,7 +627,9 @@ def finetune_tab():
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,noise_offset,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
|
noise_offset,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -631,15 +643,15 @@ def finetune_tab():
|
|||||||
)
|
)
|
||||||
|
|
||||||
button_run = gr.Button('Train model', variant='primary')
|
button_run = gr.Button('Train model', variant='primary')
|
||||||
|
|
||||||
# Setup gradio tensorboard buttons
|
# Setup gradio tensorboard buttons
|
||||||
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
||||||
|
|
||||||
button_start_tensorboard.click(
|
button_start_tensorboard.click(
|
||||||
start_tensorboard,
|
start_tensorboard,
|
||||||
inputs=logging_dir,
|
inputs=logging_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
button_stop_tensorboard.click(
|
button_stop_tensorboard.click(
|
||||||
stop_tensorboard,
|
stop_tensorboard,
|
||||||
)
|
)
|
||||||
@ -699,8 +711,11 @@ def finetune_tab():
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
optimizer,optimizer_args,noise_offset,
|
caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_run.click(train_model, inputs=settings_list)
|
button_run.click(train_model, inputs=settings_list)
|
||||||
@ -742,16 +757,19 @@ def UI(**kwargs):
|
|||||||
utilities_tab(enable_dreambooth_tab=False)
|
utilities_tab(enable_dreambooth_tab=False)
|
||||||
|
|
||||||
# Show the interface
|
# Show the interface
|
||||||
launch_kwargs={}
|
launch_kwargs = {}
|
||||||
if not kwargs.get('username', None) == '':
|
if not kwargs.get('username', None) == '':
|
||||||
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
|
launch_kwargs['auth'] = (
|
||||||
|
kwargs.get('username', None),
|
||||||
|
kwargs.get('password', None),
|
||||||
|
)
|
||||||
if kwargs.get('server_port', 0) > 0:
|
if kwargs.get('server_port', 0) > 0:
|
||||||
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
|
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
||||||
if kwargs.get('inbrowser', False):
|
if kwargs.get('inbrowser', False):
|
||||||
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
|
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
||||||
print(launch_kwargs)
|
print(launch_kwargs)
|
||||||
interface.launch(**launch_kwargs)
|
interface.launch(**launch_kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||||
@ -763,10 +781,20 @@ if __name__ == '__main__':
|
|||||||
'--password', type=str, default='', help='Password for authentication'
|
'--password', type=str, default='', help='Password for authentication'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
'--server_port',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='Port to run the server listener on',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--inbrowser', action='store_true', help='Open in browser'
|
||||||
)
|
)
|
||||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
UI(
|
||||||
|
username=args.username,
|
||||||
|
password=args.password,
|
||||||
|
inbrowser=args.inbrowser,
|
||||||
|
server_port=args.server_port,
|
||||||
|
)
|
||||||
|
30
kohya_gui.py
30
kohya_gui.py
@ -53,15 +53,16 @@ def UI(**kwargs):
|
|||||||
inbrowser = kwargs.get('inbrowser', False)
|
inbrowser = kwargs.get('inbrowser', False)
|
||||||
share = kwargs.get('share', False)
|
share = kwargs.get('share', False)
|
||||||
if username and password:
|
if username and password:
|
||||||
launch_kwargs["auth"] = (username, password)
|
launch_kwargs['auth'] = (username, password)
|
||||||
if server_port > 0:
|
if server_port > 0:
|
||||||
launch_kwargs["server_port"] = server_port
|
launch_kwargs['server_port'] = server_port
|
||||||
if inbrowser:
|
if inbrowser:
|
||||||
launch_kwargs["inbrowser"] = inbrowser
|
launch_kwargs['inbrowser'] = inbrowser
|
||||||
if share:
|
if share:
|
||||||
launch_kwargs["share"] = share
|
launch_kwargs['share'] = share
|
||||||
interface.launch(**launch_kwargs)
|
interface.launch(**launch_kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -72,11 +73,24 @@ if __name__ == '__main__':
|
|||||||
'--password', type=str, default='', help='Password for authentication'
|
'--password', type=str, default='', help='Password for authentication'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
'--server_port',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='Port to run the server listener on',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--inbrowser', action='store_true', help='Open in browser'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--share', action='store_true', help='Share the gradio UI'
|
||||||
)
|
)
|
||||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
|
||||||
parser.add_argument("--share", action="store_true", help="Share the gradio UI")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port, share=args.share)
|
UI(
|
||||||
|
username=args.username,
|
||||||
|
password=args.password,
|
||||||
|
inbrowser=args.inbrowser,
|
||||||
|
server_port=args.server_port,
|
||||||
|
share=args.share,
|
||||||
|
)
|
||||||
|
@ -9,6 +9,7 @@ refresh_symbol = '\U0001f504' # 🔄
|
|||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
|
||||||
|
|
||||||
def update_optimizer(my_data):
|
def update_optimizer(my_data):
|
||||||
if my_data.get('use_8bit_adam', False):
|
if my_data.get('use_8bit_adam', False):
|
||||||
my_data['optimizer'] = 'AdamW8bit'
|
my_data['optimizer'] = 'AdamW8bit'
|
||||||
@ -86,13 +87,18 @@ def remove_doublequote(file_path):
|
|||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
def set_legacy_8bitadam(optimizer, use_8bit_adam):
|
def set_legacy_8bitadam(optimizer, use_8bit_adam):
|
||||||
if optimizer == 'AdamW8bit':
|
if optimizer == 'AdamW8bit':
|
||||||
# use_8bit_adam = True
|
# use_8bit_adam = True
|
||||||
return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(value=True, interactive=False, visible=True)
|
return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(
|
||||||
|
value=True, interactive=False, visible=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# use_8bit_adam = False
|
# use_8bit_adam = False
|
||||||
return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(value=False, interactive=False, visible=True)
|
return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(
|
||||||
|
value=False, interactive=False, visible=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_folder_path(folder_path=''):
|
def get_folder_path(folder_path=''):
|
||||||
@ -489,14 +495,15 @@ def gradio_training(
|
|||||||
'DAdaptation',
|
'DAdaptation',
|
||||||
'Lion',
|
'Lion',
|
||||||
'SGDNesterov',
|
'SGDNesterov',
|
||||||
'SGDNesterov8bit'
|
'SGDNesterov8bit',
|
||||||
],
|
],
|
||||||
value="AdamW8bit",
|
value='AdamW8bit',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
optimizer_args = gr.Textbox(
|
optimizer_args = gr.Textbox(
|
||||||
label='Optimizer extra arguments', placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True'
|
label='Optimizer extra arguments',
|
||||||
|
placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True',
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
learning_rate,
|
learning_rate,
|
||||||
@ -549,11 +556,14 @@ def run_cmd_training(**kwargs):
|
|||||||
' --cache_latents' if kwargs.get('cache_latents') else '',
|
' --cache_latents' if kwargs.get('cache_latents') else '',
|
||||||
# ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
|
# ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
|
||||||
f' --optimizer_type="{kwargs.get("optimizer", "AdamW")}"',
|
f' --optimizer_type="{kwargs.get("optimizer", "AdamW")}"',
|
||||||
f' --optimizer_args {kwargs.get("optimizer_args", "")}' if not kwargs.get('optimizer_args') == '' else '',
|
f' --optimizer_args {kwargs.get("optimizer_args", "")}'
|
||||||
|
if not kwargs.get('optimizer_args') == ''
|
||||||
|
else '',
|
||||||
]
|
]
|
||||||
run_cmd = ''.join(options)
|
run_cmd = ''.join(options)
|
||||||
return run_cmd
|
return run_cmd
|
||||||
|
|
||||||
|
|
||||||
# # This function takes a dictionary of keyword arguments and returns a string that can be used to run a command-line training script
|
# # This function takes a dictionary of keyword arguments and returns a string that can be used to run a command-line training script
|
||||||
# def run_cmd_training(**kwargs):
|
# def run_cmd_training(**kwargs):
|
||||||
# arg_map = {
|
# arg_map = {
|
||||||
@ -611,7 +621,9 @@ def gradio_advanced_training():
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
# This use_8bit_adam element should be removed in a future release as it is no longer used
|
# This use_8bit_adam element should be removed in a future release as it is no longer used
|
||||||
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=False, visible=False)
|
use_8bit_adam = gr.Checkbox(
|
||||||
|
label='Use 8bit adam', value=False, visible=False
|
||||||
|
)
|
||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
color_aug = gr.Checkbox(label='Color augmentation', value=False)
|
||||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
||||||
@ -628,17 +640,13 @@ def gradio_advanced_training():
|
|||||||
noise_offset = gr.Textbox(
|
noise_offset = gr.Textbox(
|
||||||
label='Noise offset (0 - 1)', placeholder='(Oprional) eg: 0.1'
|
label='Noise offset (0 - 1)', placeholder='(Oprional) eg: 0.1'
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
caption_dropout_every_n_epochs = gr.Number(
|
caption_dropout_every_n_epochs = gr.Number(
|
||||||
label="Dropout caption every n epochs",
|
label='Dropout caption every n epochs', value=0
|
||||||
value=0
|
|
||||||
)
|
)
|
||||||
caption_dropout_rate = gr.Slider(
|
caption_dropout_rate = gr.Slider(
|
||||||
label="Rate of caption dropout",
|
label='Rate of caption dropout', value=0, minimum=0, maximum=1
|
||||||
value=0,
|
|
||||||
minimum=0,
|
|
||||||
maximum=1
|
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_state = gr.Checkbox(label='Save training state', value=False)
|
save_state = gr.Checkbox(label='Save training state', value=False)
|
||||||
@ -676,7 +684,9 @@ def gradio_advanced_training():
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,noise_offset,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
|
noise_offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -706,11 +716,9 @@ def run_cmd_advanced_training(**kwargs):
|
|||||||
f' --caption_dropout_rate="{kwargs.get("caption_dropout_rate", "")}"'
|
f' --caption_dropout_rate="{kwargs.get("caption_dropout_rate", "")}"'
|
||||||
if float(kwargs.get('caption_dropout_rate', 0)) > 0
|
if float(kwargs.get('caption_dropout_rate', 0)) > 0
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
|
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
|
||||||
if int(kwargs.get('bucket_reso_steps', 64)) >= 1
|
if int(kwargs.get('bucket_reso_steps', 64)) >= 1
|
||||||
else '',
|
else '',
|
||||||
|
|
||||||
' --save_state' if kwargs.get('save_state') else '',
|
' --save_state' if kwargs.get('save_state') else '',
|
||||||
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
|
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
|
||||||
' --color_aug' if kwargs.get('color_aug') else '',
|
' --color_aug' if kwargs.get('color_aug') else '',
|
||||||
@ -734,6 +742,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||||||
run_cmd = ''.join(options)
|
run_cmd = ''.join(options)
|
||||||
return run_cmd
|
return run_cmd
|
||||||
|
|
||||||
|
|
||||||
# def run_cmd_advanced_training(**kwargs):
|
# def run_cmd_advanced_training(**kwargs):
|
||||||
# arg_map = {
|
# arg_map = {
|
||||||
# 'max_train_epochs': ' --max_train_epochs="{}"',
|
# 'max_train_epochs': ' --max_train_epochs="{}"',
|
||||||
@ -763,4 +772,4 @@ def run_cmd_advanced_training(**kwargs):
|
|||||||
|
|
||||||
# cmd = ''.join(options)
|
# cmd = ''.join(options)
|
||||||
|
|
||||||
# return cmd
|
# return cmd
|
||||||
|
@ -217,7 +217,7 @@ def gradio_convert_model_tab():
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
target_save_precision_type = gr.Dropdown(
|
target_save_precision_type = gr.Dropdown(
|
||||||
label='Target model precison',
|
label='Target model precision',
|
||||||
choices=['unspecified', 'fp16', 'bf16', 'float'],
|
choices=['unspecified', 'fp16', 'bf16', 'float'],
|
||||||
value='unspecified',
|
value='unspecified',
|
||||||
)
|
)
|
||||||
|
@ -115,7 +115,7 @@ def gradio_extract_lora_tab():
|
|||||||
outputs=save_to,
|
outputs=save_to,
|
||||||
)
|
)
|
||||||
save_precision = gr.Dropdown(
|
save_precision = gr.Dropdown(
|
||||||
label='Save precison',
|
label='Save precision',
|
||||||
choices=['fp16', 'bf16', 'float'],
|
choices=['fp16', 'bf16', 'float'],
|
||||||
value='float',
|
value='float',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
|
@ -121,13 +121,13 @@ def gradio_merge_lora_tab():
|
|||||||
outputs=save_to,
|
outputs=save_to,
|
||||||
)
|
)
|
||||||
precision = gr.Dropdown(
|
precision = gr.Dropdown(
|
||||||
label='Merge precison',
|
label='Merge precision',
|
||||||
choices=['fp16', 'bf16', 'float'],
|
choices=['fp16', 'bf16', 'float'],
|
||||||
value='float',
|
value='float',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
save_precision = gr.Dropdown(
|
save_precision = gr.Dropdown(
|
||||||
label='Save precison',
|
label='Save precision',
|
||||||
choices=['fp16', 'bf16', 'float'],
|
choices=['fp16', 'bf16', 'float'],
|
||||||
value='float',
|
value='float',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -94,7 +94,7 @@ def gradio_resize_lora_tab():
|
|||||||
outputs=save_to,
|
outputs=save_to,
|
||||||
)
|
)
|
||||||
save_precision = gr.Dropdown(
|
save_precision = gr.Dropdown(
|
||||||
label='Save precison',
|
label='Save precision',
|
||||||
choices=['fp16', 'bf16', 'float'],
|
choices=['fp16', 'bf16', 'float'],
|
||||||
value='fp16',
|
value='fp16',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
|
@ -4,43 +4,49 @@ from easygui import msgbox
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
|
||||||
tensorboard_proc = None # I know... bad but heh
|
tensorboard_proc = None # I know... bad but heh
|
||||||
|
|
||||||
|
|
||||||
def start_tensorboard(logging_dir):
|
def start_tensorboard(logging_dir):
|
||||||
global tensorboard_proc
|
global tensorboard_proc
|
||||||
|
|
||||||
if not os.listdir(logging_dir):
|
if not os.listdir(logging_dir):
|
||||||
print("Error: log folder is empty")
|
print('Error: log folder is empty')
|
||||||
msgbox(msg="Error: log folder is empty")
|
msgbox(msg='Error: log folder is empty')
|
||||||
return
|
return
|
||||||
|
|
||||||
run_cmd = f'tensorboard.exe --logdir "{logging_dir}"'
|
run_cmd = f'tensorboard.exe --logdir "{logging_dir}"'
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
if tensorboard_proc is not None:
|
if tensorboard_proc is not None:
|
||||||
print("Tensorboard is already running. Terminating existing process before starting new one...")
|
print(
|
||||||
|
'Tensorboard is already running. Terminating existing process before starting new one...'
|
||||||
|
)
|
||||||
stop_tensorboard()
|
stop_tensorboard()
|
||||||
|
|
||||||
# Start background process
|
# Start background process
|
||||||
print('Starting tensorboard...')
|
print('Starting tensorboard...')
|
||||||
tensorboard_proc = subprocess.Popen(run_cmd)
|
tensorboard_proc = subprocess.Popen(run_cmd)
|
||||||
|
|
||||||
# Wait for some time to allow TensorBoard to start up
|
# Wait for some time to allow TensorBoard to start up
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
# Open the TensorBoard URL in the default browser
|
# Open the TensorBoard URL in the default browser
|
||||||
print('Opening tensorboard url in browser...')
|
print('Opening tensorboard url in browser...')
|
||||||
import webbrowser
|
import webbrowser
|
||||||
|
|
||||||
webbrowser.open('http://localhost:6006')
|
webbrowser.open('http://localhost:6006')
|
||||||
|
|
||||||
|
|
||||||
def stop_tensorboard():
|
def stop_tensorboard():
|
||||||
print('Stopping tensorboard process...')
|
print('Stopping tensorboard process...')
|
||||||
tensorboard_proc.kill()
|
tensorboard_proc.kill()
|
||||||
print('...process stopped')
|
print('...process stopped')
|
||||||
|
|
||||||
|
|
||||||
def gradio_tensorboard():
|
def gradio_tensorboard():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
button_start_tensorboard = gr.Button('Start tensorboard')
|
button_start_tensorboard = gr.Button('Start tensorboard')
|
||||||
button_stop_tensorboard = gr.Button('Stop tensorboard')
|
button_stop_tensorboard = gr.Button('Stop tensorboard')
|
||||||
|
|
||||||
return(button_start_tensorboard, button_stop_tensorboard)
|
return (button_start_tensorboard, button_stop_tensorboard)
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -50,16 +50,19 @@ def UI(**kwargs):
|
|||||||
utilities_tab()
|
utilities_tab()
|
||||||
|
|
||||||
# Show the interface
|
# Show the interface
|
||||||
launch_kwargs={}
|
launch_kwargs = {}
|
||||||
if not kwargs.get('username', None) == '':
|
if not kwargs.get('username', None) == '':
|
||||||
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
|
launch_kwargs['auth'] = (
|
||||||
|
kwargs.get('username', None),
|
||||||
|
kwargs.get('password', None),
|
||||||
|
)
|
||||||
if kwargs.get('server_port', 0) > 0:
|
if kwargs.get('server_port', 0) > 0:
|
||||||
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
|
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
||||||
if kwargs.get('inbrowser', False):
|
if kwargs.get('inbrowser', False):
|
||||||
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
|
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
||||||
print(launch_kwargs)
|
print(launch_kwargs)
|
||||||
interface.launch(**launch_kwargs)
|
interface.launch(**launch_kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||||
@ -71,10 +74,20 @@ if __name__ == '__main__':
|
|||||||
'--password', type=str, default='', help='Password for authentication'
|
'--password', type=str, default='', help='Password for authentication'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
'--server_port',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='Port to run the server listener on',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--inbrowser', action='store_true', help='Open in browser'
|
||||||
)
|
)
|
||||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
UI(
|
||||||
|
username=args.username,
|
||||||
|
password=args.password,
|
||||||
|
inbrowser=args.inbrowser,
|
||||||
|
server_port=args.server_port,
|
||||||
|
)
|
||||||
|
1
locon
Submodule
1
locon
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit 143b7b1e33a4253b13f45526de41df748b97e585
|
114
lora_gui.py
114
lora_gui.py
@ -46,6 +46,20 @@ folder_symbol = '\U0001f4c2' # 📂
|
|||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
save_style_symbol = '\U0001f4be' # 💾
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
document_symbol = '\U0001F4C4' # 📄
|
||||||
|
path_of_this_folder = os.getcwd()
|
||||||
|
|
||||||
|
def getlocon(existance):
    """Fetch or refresh the LoCon scripts in the ``locon`` sub-folder.

    Args:
        existance: truthy when ``<path_of_this_folder>/locon`` already
            exists; ``git pull`` is then run inside it. Otherwise the LoCon
            repository is cloned into ``locon`` under
            ``path_of_this_folder``.

    The git command is run with an explicit ``cwd`` instead of changing the
    process-wide working directory, so the caller's working directory is
    never touched (and cannot be left pointing elsewhere if a step fails).
    Failures are reported on stdout and otherwise ignored: this is a
    best-effort update, exactly like the previous ``os.system`` calls whose
    exit status was not inspected.
    """
    # Local import keeps this block independent of the file's import header.
    import subprocess

    if existance:
        print('Checking LoCon script version...')
        git_cmd = ['git', 'pull']
        work_dir = os.path.join(path_of_this_folder, 'locon')
    else:
        git_cmd = [
            'git',
            'clone',
            'https://github.com/KohakuBlueleaf/LoCon.git',
            'locon',
        ]
        work_dir = path_of_this_folder

    cmd_text = ' '.join(git_cmd)
    try:
        # Argument list + cwd: no shell parsing, no global chdir.
        completed = subprocess.run(git_cmd, cwd=work_dir)
    except OSError as err:
        # git missing from PATH or work_dir not accessible.
        print(f'Could not run "{cmd_text}" in {work_dir}: {err}')
        return

    if completed.returncode != 0:
        print(
            f'"{cmd_text}" exited with status {completed.returncode}; '
            'the LoCon scripts may be missing or out of date.'
        )
|
||||||
|
|
||||||
|
|
||||||
def save_configuration(
|
def save_configuration(
|
||||||
save_as,
|
save_as,
|
||||||
@ -105,9 +119,11 @@ def save_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,noise_offset,
|
optimizer_args,noise_offset,
|
||||||
|
locon=0, conv_dim=0, conv_alpha=0,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -211,9 +227,11 @@ def open_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,noise_offset,
|
optimizer_args,noise_offset,
|
||||||
|
locon=0, conv_dim=0, conv_alpha=0,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -237,7 +255,7 @@ def open_configuration(
|
|||||||
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
||||||
if not key in ['file_path']:
|
if not key in ['file_path']:
|
||||||
values.append(my_data.get(key, value))
|
values.append(my_data.get(key, value))
|
||||||
|
|
||||||
return tuple(values)
|
return tuple(values)
|
||||||
|
|
||||||
|
|
||||||
@ -297,9 +315,11 @@ def train_model(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,noise_offset,
|
optimizer_args,noise_offset,
|
||||||
|
locon, conv_dim, conv_alpha,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -435,7 +455,12 @@ def train_model(
|
|||||||
run_cmd += f' --save_model_as={save_model_as}'
|
run_cmd += f' --save_model_as={save_model_as}'
|
||||||
if not float(prior_loss_weight) == 1.0:
|
if not float(prior_loss_weight) == 1.0:
|
||||||
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
|
||||||
run_cmd += f' --network_module=networks.lora'
|
if locon:
|
||||||
|
getlocon(os.path.exists(os.path.join(path_of_this_folder, 'locon')))
|
||||||
|
run_cmd += f' --network_module=locon.locon.locon_kohya'
|
||||||
|
run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'
|
||||||
|
else:
|
||||||
|
run_cmd += f' --network_module=networks.lora'
|
||||||
|
|
||||||
if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
|
if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
|
||||||
if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
|
if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
|
||||||
@ -653,19 +678,19 @@ def lora_tab(
|
|||||||
placeholder='Optional',
|
placeholder='Optional',
|
||||||
)
|
)
|
||||||
network_dim = gr.Slider(
|
network_dim = gr.Slider(
|
||||||
minimum=4,
|
minimum=1,
|
||||||
maximum=1024,
|
maximum=1024,
|
||||||
label='Network Rank (Dimension)',
|
label='Network Rank (Dimension)',
|
||||||
value=8,
|
value=8,
|
||||||
step=4,
|
step=1,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
network_alpha = gr.Slider(
|
network_alpha = gr.Slider(
|
||||||
minimum=4,
|
minimum=1,
|
||||||
maximum=1024,
|
maximum=1024,
|
||||||
label='Network Alpha',
|
label='Network Alpha',
|
||||||
value=1,
|
value=1,
|
||||||
step=4,
|
step=1,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -683,6 +708,22 @@ def lora_tab(
|
|||||||
)
|
)
|
||||||
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
enable_bucket = gr.Checkbox(label='Enable buckets', value=True)
|
||||||
with gr.Accordion('Advanced Configuration', open=False):
|
with gr.Accordion('Advanced Configuration', open=False):
|
||||||
|
with gr.Row():
|
||||||
|
locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False)
|
||||||
|
conv_dim = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=512,
|
||||||
|
value=1,
|
||||||
|
step=1,
|
||||||
|
label='LoCon Convolution Rank (Dimension)',
|
||||||
|
)
|
||||||
|
conv_alpha = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=512,
|
||||||
|
value=1,
|
||||||
|
step=1,
|
||||||
|
label='LoCon Convolution Alpha',
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
no_token_padding = gr.Checkbox(
|
no_token_padding = gr.Checkbox(
|
||||||
label='No token padding', value=False
|
label='No token padding', value=False
|
||||||
@ -723,14 +764,16 @@ def lora_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,noise_offset,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
|
noise_offset,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
inputs=[color_aug],
|
inputs=[color_aug],
|
||||||
outputs=[cache_latents],
|
outputs=[cache_latents],
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer.change(
|
optimizer.change(
|
||||||
set_legacy_8bitadam,
|
set_legacy_8bitadam,
|
||||||
inputs=[optimizer, use_8bit_adam],
|
inputs=[optimizer, use_8bit_adam],
|
||||||
@ -753,15 +796,15 @@ def lora_tab(
|
|||||||
gradio_verify_lora_tab()
|
gradio_verify_lora_tab()
|
||||||
|
|
||||||
button_run = gr.Button('Train model', variant='primary')
|
button_run = gr.Button('Train model', variant='primary')
|
||||||
|
|
||||||
# Setup gradio tensorboard buttons
|
# Setup gradio tensorboard buttons
|
||||||
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
||||||
|
|
||||||
button_start_tensorboard.click(
|
button_start_tensorboard.click(
|
||||||
start_tensorboard,
|
start_tensorboard,
|
||||||
inputs=logging_dir,
|
inputs=logging_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
button_stop_tensorboard.click(
|
button_stop_tensorboard.click(
|
||||||
stop_tensorboard,
|
stop_tensorboard,
|
||||||
)
|
)
|
||||||
@ -822,9 +865,11 @@ def lora_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
optimizer,
|
optimizer,
|
||||||
optimizer_args,noise_offset,
|
optimizer_args,noise_offset,
|
||||||
|
locon, conv_dim, conv_alpha,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
@ -886,16 +931,19 @@ def UI(**kwargs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Show the interface
|
# Show the interface
|
||||||
launch_kwargs={}
|
launch_kwargs = {}
|
||||||
if not kwargs.get('username', None) == '':
|
if not kwargs.get('username', None) == '':
|
||||||
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
|
launch_kwargs['auth'] = (
|
||||||
|
kwargs.get('username', None),
|
||||||
|
kwargs.get('password', None),
|
||||||
|
)
|
||||||
if kwargs.get('server_port', 0) > 0:
|
if kwargs.get('server_port', 0) > 0:
|
||||||
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
|
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
||||||
if kwargs.get('inbrowser', False):
|
if kwargs.get('inbrowser', False):
|
||||||
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
|
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
||||||
print(launch_kwargs)
|
print(launch_kwargs)
|
||||||
interface.launch(**launch_kwargs)
|
interface.launch(**launch_kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||||
@ -907,10 +955,20 @@ if __name__ == '__main__':
|
|||||||
'--password', type=str, default='', help='Password for authentication'
|
'--password', type=str, default='', help='Password for authentication'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
'--server_port',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='Port to run the server listener on',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--inbrowser', action='store_true', help='Open in browser'
|
||||||
)
|
)
|
||||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
UI(
|
||||||
|
username=args.username,
|
||||||
|
password=args.password,
|
||||||
|
inbrowser=args.inbrowser,
|
||||||
|
server_port=args.server_port,
|
||||||
|
)
|
||||||
|
194
networks/extract_lora_from_models copy.py
Normal file
194
networks/extract_lora_from_models copy.py
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
# extract approximating LoRA by svd from two SD models
|
||||||
|
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||||
|
# Thanks to cloneofsimo!
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
import library.model_util as model_util
|
||||||
|
import lora
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
CLAMP_QUANTILE = 1 # 0.99
|
||||||
|
MIN_DIFF = 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def save_to_file(file_name, model, state_dict, dtype):
    """Optionally cast the tensors of ``state_dict`` and write ``model`` to disk.

    Args:
        file_name: destination path. A ``.safetensors`` extension selects
            ``safetensors.torch.save_file``; anything else uses ``torch.save``.
        model: the object that is actually serialized.
        state_dict: mapping whose tensor values are cast **in place** to
            ``dtype``. Non-tensor values are left untouched.
        dtype: target ``torch.dtype``, or ``None`` to skip the cast.

    NOTE(review): the cast is applied to ``state_dict`` but ``model`` is what
    gets written, so the cast only reaches the file when both arguments refer
    to the same mapping -- presumably how callers use it; confirm.
    """
    if dtype is not None:
        for key, value in list(state_dict.items()):
            # isinstance (not an exact type match) so Tensor subclasses such
            # as nn.Parameter are converted as well.
            if isinstance(value, torch.Tensor):
                state_dict[key] = value.to(dtype)

    if os.path.splitext(file_name)[1] == '.safetensors':
        save_file(model, file_name)
    else:
        torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
|
def svd(args):
    """Extract an approximating LoRA from the weight difference of two models.

    Both checkpoints are loaded, a LoRA network is created on each of them to
    enumerate the target modules, and for every module the difference
    ``tuned.weight - original.weight`` is factorized with an SVD truncated to
    ``args.dim`` components. The factors are loaded into the LoRA network
    built on the original model, which is then saved to ``args.save_to``.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``v2``, ``model_org``, ``model_tuned``, ``dim``, ``device``,
            ``save_precision`` and ``save_to``.

    Raises:
        AssertionError: if the two models expose a different number of text
            encoder LoRA modules, or an extracted factor does not match the
            shape of the LoRA parameter it is meant to replace.
    """
    # Unknown names (and None) map to None, i.e. "keep the current precision".
    dtype_by_name = {
        'float': torch.float,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }
    save_dtype = dtype_by_name.get(args.save_precision)

    print(f"loading SD model : {args.model_org}")
    text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
    print(f"loading SD model : {args.model_tuned}")
    text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)

    # Create LoRA networks to enumerate the modules to extract.
    # The network alpha is 1.5 x the rank (not equal to the rank).
    network_alpha = args.dim * 1.5
    lora_network_o = lora.create_network(1.0, args.dim, network_alpha, None, text_encoder_o, unet_o)
    lora_network_t = lora.create_network(1.0, args.dim, network_alpha, None, text_encoder_t, unet_t)
    assert len(lora_network_o.text_encoder_loras) == len(lora_network_t.text_encoder_loras), \
        "model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "

    # Collect per-module weight differences, text encoder first.
    diffs = {}
    text_encoder_different = False
    for lora_o, lora_t in zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras):
        diff = lora_t.org_module.weight - lora_o.org_module.weight

        # The text encoder may be identical in both models.
        if torch.max(torch.abs(diff)) > MIN_DIFF:
            text_encoder_different = True

        diffs[lora_o.lora_name] = diff.float()

    if not text_encoder_different:
        print("Text encoder is same. Extract U-Net only.")
        # Drop the text encoder modules entirely so only U-Net weights are extracted.
        lora_network_o.text_encoder_loras = []
        diffs = {}

    for lora_o, lora_t in zip(lora_network_o.unet_loras, lora_network_t.unet_loras):
        diff = (lora_t.org_module.weight - lora_o.org_module.weight).float()

        if args.device:
            diff = diff.to(args.device)

        diffs[lora_o.lora_name] = diff

    # Factorize every difference: diff ~= (U * S)[:, :rank] @ Vt[:rank, :]
    print("calculating by SVD")
    rank = args.dim
    lora_weights = {}
    with torch.no_grad():
        for lora_name, mat in tqdm(list(diffs.items())):
            if mat.dim() == 4:
                # 4-D (conv) weight: squeeze away size-1 dims before the SVD.
                # NOTE(review): presumably these are 1x1 conv kernels -- confirm.
                mat = mat.squeeze()

            U, S, Vt = torch.linalg.svd(mat)

            # Fold the singular values into the "up" factor.
            U = U[:, :rank] @ torch.diag(S[:rank])
            Vt = Vt[:rank, :]

            lora_weights[lora_name] = (U, Vt)

    # Build the LoRA state dict and overwrite its weights with the factors.
    lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True)  # to make state dict
    lora_sd = lora_network_o.state_dict()
    print(f"LoRA has {len(lora_sd)} weights.")

    for key in list(lora_sd.keys()):
        if "alpha" in key:
            continue

        lora_name = key.split('.')[0]
        # Index 0 is the up factor (U * S), index 1 the down factor (Vt).
        factor_index = 0 if "lora_up" in key else 1

        weights = lora_weights[lora_name][factor_index]
        if len(lora_sd[key].size()) == 4:
            # Restore the two trailing dims of a 4-D parameter.
            weights = weights.unsqueeze(2).unsqueeze(3)

        assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
        lora_sd[key] = weights

    # Load the extracted weights into the network and save it.
    info = lora_network_o.load_state_dict(lora_sd)
    print(f"Loading extracted LoRA weights: {info}")

    dir_name = os.path.dirname(args.save_to)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name, exist_ok=True)

    # Minimum metadata; alpha must match the value used to build the network.
    metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(network_alpha)}

    lora_network_o.save_weights(args.save_to, save_dtype, metadata)
    print(f"LoRA weights are saved to: {args.save_to}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the extraction.
    parser = argparse.ArgumentParser()
    parser.add_argument("--v2", action='store_true',
                        help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
    parser.add_argument("--save_precision", type=str, default=None,
                        choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat")
    # The three paths below are mandatory: without them svd() would only fail
    # late (after loading models) with an unhelpful error, so let argparse
    # reject the invocation up front.
    parser.add_argument("--model_org", type=str, required=True,
                        help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
    parser.add_argument("--model_tuned", type=str, required=True,
                        help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
    parser.add_argument("--save_to", type=str, required=True,
                        help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
    parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
    parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")

    args = parser.parse_args()
    svd(args)
|
@ -38,10 +38,11 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
|||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
def resize_lora_model(lora_sd, new_rank, save_dtype, device, sv_ratio, verbose):
|
||||||
network_alpha = None
|
network_alpha = None
|
||||||
network_dim = None
|
network_dim = None
|
||||||
verbose_str = "\n"
|
verbose_str = "\n"
|
||||||
|
ratio_flag = False
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
@ -57,9 +58,12 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
|||||||
network_alpha = network_dim
|
network_alpha = network_dim
|
||||||
|
|
||||||
scale = network_alpha/network_dim
|
scale = network_alpha/network_dim
|
||||||
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
if not sv_ratio:
|
||||||
|
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
||||||
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
|
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new dim: {new_rank}, new alpha: {new_alpha}")
|
||||||
|
else:
|
||||||
|
print(f"Dynamically determining new alphas and dims based off sv ratio: {sv_ratio}")
|
||||||
|
ratio_flag = True
|
||||||
|
|
||||||
lora_down_weight = None
|
lora_down_weight = None
|
||||||
lora_up_weight = None
|
lora_up_weight = None
|
||||||
@ -97,11 +101,24 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
|||||||
|
|
||||||
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
||||||
|
|
||||||
|
if ratio_flag:
|
||||||
|
# Calculate new dim and alpha for dynamic sizing
|
||||||
|
max_sv = S[0]
|
||||||
|
min_sv = max_sv/sv_ratio
|
||||||
|
new_rank = torch.sum(S > min_sv).item()
|
||||||
|
new_rank = max(new_rank, 1)
|
||||||
|
new_alpha = float(scale*new_rank)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
s_sum = torch.sum(torch.abs(S))
|
s_sum = torch.sum(torch.abs(S))
|
||||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||||
verbose_str+=f"{block_down_name:76} | "
|
verbose_str+=f"{block_down_name:75} | "
|
||||||
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
|
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}"
|
||||||
|
|
||||||
|
if verbose and ratio_flag:
|
||||||
|
verbose_str+=f", dynamic| dim: {new_rank}, alpha: {new_alpha}\n"
|
||||||
|
else:
|
||||||
|
verbose_str+=f"\n"
|
||||||
|
|
||||||
U = U[:, :new_rank]
|
U = U[:, :new_rank]
|
||||||
S = S[:new_rank]
|
S = S[:new_rank]
|
||||||
@ -160,16 +177,21 @@ def resize(args):
|
|||||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||||
|
|
||||||
print("resizing rank...")
|
print("resizing rank...")
|
||||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
|
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.sv_ratio, args.verbose)
|
||||||
|
|
||||||
# update metadata
|
# update metadata
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
comment = metadata.get("ss_training_comment", "")
|
comment = metadata.get("ss_training_comment", "")
|
||||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
if not args.sv_ratio:
|
||||||
metadata["ss_network_dim"] = str(args.new_rank)
|
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
||||||
metadata["ss_network_alpha"] = str(new_alpha)
|
metadata["ss_network_dim"] = str(args.new_rank)
|
||||||
|
metadata["ss_network_alpha"] = str(new_alpha)
|
||||||
|
else:
|
||||||
|
metadata["ss_training_comment"] = f"Dynamic resize from {old_dim} with ratio {args.sv_ratio}; {comment}"
|
||||||
|
metadata["ss_network_dim"] = 'Dynamic'
|
||||||
|
metadata["ss_network_alpha"] = 'Dynamic'
|
||||||
|
|
||||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||||
metadata["sshs_model_hash"] = model_hash
|
metadata["sshs_model_hash"] = model_hash
|
||||||
@ -193,6 +215,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
parser.add_argument("--verbose", action="store_true",
|
parser.add_argument("--verbose", action="store_true",
|
||||||
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
||||||
|
parser.add_argument("--sv_ratio", type=float, default=None,
|
||||||
|
help="Specify svd ratio for dim calcs. Will override --new_rank")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
resize(args)
|
resize(args)
|
@ -101,8 +101,11 @@ def save_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
optimizer,optimizer_args,noise_offset,
|
caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -205,8 +208,11 @@ def open_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
optimizer,optimizer_args,noise_offset,
|
caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -288,8 +294,11 @@ def train_model(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
optimizer,optimizer_args,noise_offset,
|
caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -641,7 +650,8 @@ def ti_tab(
|
|||||||
seed,
|
seed,
|
||||||
caption_extension,
|
caption_extension,
|
||||||
cache_latents,
|
cache_latents,
|
||||||
optimizer,optimizer_args,
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
) = gradio_training(
|
) = gradio_training(
|
||||||
learning_rate_value='1e-5',
|
learning_rate_value='1e-5',
|
||||||
lr_scheduler_value='cosine',
|
lr_scheduler_value='cosine',
|
||||||
@ -699,7 +709,9 @@ def ti_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,noise_offset,
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate,
|
||||||
|
noise_offset,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -723,15 +735,15 @@ def ti_tab(
|
|||||||
)
|
)
|
||||||
|
|
||||||
button_run = gr.Button('Train model', variant='primary')
|
button_run = gr.Button('Train model', variant='primary')
|
||||||
|
|
||||||
# Setup gradio tensorboard buttons
|
# Setup gradio tensorboard buttons
|
||||||
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
|
||||||
|
|
||||||
button_start_tensorboard.click(
|
button_start_tensorboard.click(
|
||||||
start_tensorboard,
|
start_tensorboard,
|
||||||
inputs=logging_dir,
|
inputs=logging_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
button_stop_tensorboard.click(
|
button_stop_tensorboard.click(
|
||||||
stop_tensorboard,
|
stop_tensorboard,
|
||||||
)
|
)
|
||||||
@ -791,8 +803,11 @@ def ti_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
caption_dropout_every_n_epochs, caption_dropout_rate,
|
caption_dropout_every_n_epochs,
|
||||||
optimizer,optimizer_args,noise_offset,
|
caption_dropout_rate,
|
||||||
|
optimizer,
|
||||||
|
optimizer_args,
|
||||||
|
noise_offset,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
@ -854,16 +869,19 @@ def UI(**kwargs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Show the interface
|
# Show the interface
|
||||||
launch_kwargs={}
|
launch_kwargs = {}
|
||||||
if not kwargs.get('username', None) == '':
|
if not kwargs.get('username', None) == '':
|
||||||
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
|
launch_kwargs['auth'] = (
|
||||||
|
kwargs.get('username', None),
|
||||||
|
kwargs.get('password', None),
|
||||||
|
)
|
||||||
if kwargs.get('server_port', 0) > 0:
|
if kwargs.get('server_port', 0) > 0:
|
||||||
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
|
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
||||||
if kwargs.get('inbrowser', False):
|
if kwargs.get('inbrowser', False):
|
||||||
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
|
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
||||||
print(launch_kwargs)
|
print(launch_kwargs)
|
||||||
interface.launch(**launch_kwargs)
|
interface.launch(**launch_kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||||
@ -875,10 +893,20 @@ if __name__ == '__main__':
|
|||||||
'--password', type=str, default='', help='Password for authentication'
|
'--password', type=str, default='', help='Password for authentication'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
'--server_port',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='Port to run the server listener on',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--inbrowser', action='store_true', help='Open in browser'
|
||||||
)
|
)
|
||||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
UI(
|
||||||
|
username=args.username,
|
||||||
|
password=args.password,
|
||||||
|
inbrowser=args.inbrowser,
|
||||||
|
server_port=args.server_port,
|
||||||
|
)
|
||||||
|
@ -24,7 +24,7 @@ python convert_diffusers20_original_sd.py ..\models\sd.ckpt
|
|||||||
Specify the .ckpt file and the destination folder as arguments.
|
Specify the .ckpt file and the destination folder as arguments.
|
||||||
Model judgment is not possible, so please use the `--v1` option or the `--v2` option depending on the model.
|
Model judgment is not possible, so please use the `--v1` option or the `--v2` option depending on the model.
|
||||||
|
|
||||||
Also, since `.ckpt` does not contain schduler and tokenizer information, you need to copy them from some existing Diffusers model. Please specify with `--reference_model`. You can specify the HuggingFace id or a local model directory.
|
Also, since `.ckpt` does not contain scheduler and tokenizer information, you need to copy them from some existing Diffusers model. Please specify with `--reference_model`. You can specify the HuggingFace id or a local model directory.
|
||||||
|
|
||||||
If you don't have a local model, you can specify "stabilityai/stable-diffusion-2" or "stabilityai/stable-diffusion-2-base" for v2.
|
If you don't have a local model, you can specify "stabilityai/stable-diffusion-2" or "stabilityai/stable-diffusion-2-base" for v2.
|
||||||
For v1.4/1.5, "CompVis/stable-diffusion-v1-4" is fine (v1.4 and v1.5 seem to be the same).
|
For v1.4/1.5, "CompVis/stable-diffusion-v1-4" is fine (v1.4 and v1.5 seem to be the same).
|
||||||
|
Loading…
Reference in New Issue
Block a user