- Add new tool to verify LoRA weights produced by the trainer. Can be found under "Dreambooth LoRA/Tools/Verify LoRA

This commit is contained in:
bmaltais 2023-01-22 11:40:14 -05:00
parent 2ca17f69dd
commit 511361c80b
4 changed files with 111 additions and 1 deletions

View File

@ -114,8 +114,24 @@ Once you have created the LoRA network you can generate images via auto1111 by i
- Re-install python 3.10.x on your system: https://www.python.org/ftp/python/3.10.9/python-3.10.9-amd64.exe
### FileNotFoundError
This is usually related to an installation issue. Make sure you do not have python modules installed locally that could conflict with the ones installed in the venv:
1. Open a new powershell terminal and make sure no venv is active.
2. Run the following commands
```
pip freeze > uninstall.txt
pip uninstall -r uninstall.txt
```
Then redo the installation instruction within the kohya_ss venv.
## Change history
* 2023/01/22 (v20.4.1):
- Add new tool to verify LoRA weights produced by the trainer. Can be found under "Dreambooth LoRA/Tools/Verify LoRA"
* 2023/01/22 (v20.4.0):
- Add support for `network_alpha` under the Training tab and support for `--training_comment` under the Folders tab.
- Add ``--network_alpha`` option to specify ``alpha`` value to prevent underflows for stable training. Thanks to CCRcmcpe!

View File

@ -0,0 +1,91 @@
import gradio as gr
from easygui import msgbox
import subprocess
import os
from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
def verify_lora(
lora_model,
):
# verify for caption_text_input
if lora_model == '':
msgbox('Invalid model A file')
return
# verify if source model exist
if not os.path.isfile(lora_model):
msgbox('The provided model A is not a file')
return
run_cmd = f'.\\venv\Scripts\python.exe "networks\check_lora_weights.py"'
run_cmd += f' {lora_model}'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
process = subprocess.Popen(run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, error = process.communicate()
return (output.decode(), error.decode())
###
# Gradio UI
###
def gradio_verify_lora_tab():
with gr.Tab('Verify LoRA'):
gr.Markdown(
'This utility can verify a LoRA network to make sure it is properly trained.'
)
lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
with gr.Row():
lora_model = gr.Textbox(
label='LoRA model',
placeholder='Path to the LoRA model to verify',
interactive=True,
)
button_lora_model_file = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
button_lora_model_file.click(
get_file_path,
inputs=[lora_model, lora_ext, lora_ext_name],
outputs=lora_model,
)
verify_button = gr.Button('Verify', variant="primary")
lora_model_verif_output = gr.Textbox(
label='Output',
placeholder='Verification output',
interactive=False,
lines=1,
max_lines=10,
)
lora_model_verif_error = gr.Textbox(
label='Error',
placeholder='Verification error',
interactive=False,
lines=1,
max_lines=10,
)
verify_button.click(
verify_lora,
inputs=[
lora_model,
],
outputs=[lora_model_verif_output, lora_model_verif_error]
)

View File

@ -31,6 +31,7 @@ from library.dreambooth_folder_creation_gui import (
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab
from library.merge_lora_gui import gradio_merge_lora_tab
from library.verify_lora_gui import gradio_verify_lora_tab
from easygui import msgbox
folder_symbol = '\U0001f4c2' # 📂
@ -675,6 +676,8 @@ def lora_tab(
)
gradio_dataset_balancing_tab()
gradio_merge_lora_tab()
gradio_verify_lora_tab()
button_run = gr.Button('Train model')

View File

@ -1,3 +1,3 @@
from setuptools import setup, find_packages
setup(name = "library", version="1.0.1", packages = find_packages())
setup(name = "library", version="1.0.2", packages = find_packages())