From 511361c80b1971a6716fb8ad2cb3142e08e7c118 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 22 Jan 2023 11:40:14 -0500 Subject: [PATCH] - Add new tool to verify LoRA weights produced by the trainer. Can be found under "Dreambooth LoRA/Tools/Verify LoRA --- README.md | 16 +++++++ library/verify_lora_gui.py | 91 ++++++++++++++++++++++++++++++++++++++ lora_gui.py | 3 ++ setup.py | 2 +- 4 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 library/verify_lora_gui.py diff --git a/README.md b/README.md index c074b2f..6148d85 100644 --- a/README.md +++ b/README.md @@ -114,8 +114,24 @@ Once you have created the LoRA network you can generate images via auto1111 by i - Re-install python 3.10.x on your system: https://www.python.org/ftp/python/3.10.9/python-3.10.9-amd64.exe +### FileNotFoundError + +This is usually related to an installation issue. Make sure you do not have python modules installed locally that could conflict with the ones installed in the venv: + +1. Open a new powershell terminal and make sure no venv is active. +2. Run the following commands + +``` +pip freeze > uninstall.txt +pip uninstall -r uninstall.txt +``` + +Then redo the installation instruction within the kohya_ss venv. + ## Change history +* 2023/01/22 (v20.4.1): + - Add new tool to verify LoRA weights produced by the trainer. Can be found under "Dreambooth LoRA/Tools/Verify LoRA" * 2023/01/22 (v20.4.0): - Add support for `network_alpha` under the Training tab and support for `--training_comment` under the Folders tab. - Add ``--network_alpha`` option to specify ``alpha`` value to prevent underflows for stable training. Thanks to CCRcmcpe! diff --git a/library/verify_lora_gui.py b/library/verify_lora_gui.py new file mode 100644 index 0000000..ada20d1 --- /dev/null +++ b/library/verify_lora_gui.py @@ -0,0 +1,91 @@ +import gradio as gr +from easygui import msgbox +import subprocess +import os +from .common_gui import get_saveasfilename_path, get_any_file_path, get_file_path + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +document_symbol = '\U0001F4C4' # 📄 + + +def verify_lora( + lora_model, +): + # verify for caption_text_input + if lora_model == '': + msgbox('Invalid model A file') + return + + # verify if source model exist + if not os.path.isfile(lora_model): + msgbox('The provided model A is not a file') + return + + run_cmd = f'.\\venv\Scripts\python.exe "networks\check_lora_weights.py"' + run_cmd += f' {lora_model}' + + print(run_cmd) + + # Run the command + subprocess.run(run_cmd) + process = subprocess.Popen(run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, error = process.communicate() + + return (output.decode(), error.decode()) + + +### +# Gradio UI +### + + +def gradio_verify_lora_tab(): + with gr.Tab('Verify LoRA'): + gr.Markdown( + 'This utility can verify a LoRA network to make sure it is properly trained.' + ) + + lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) + lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) + + with gr.Row(): + lora_model = gr.Textbox( + label='LoRA model', + placeholder='Path to the LoRA model to verify', + interactive=True, + ) + button_lora_model_file = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + button_lora_model_file.click( + get_file_path, + inputs=[lora_model, lora_ext, lora_ext_name], + outputs=lora_model, + ) + verify_button = gr.Button('Verify', variant="primary") + + lora_model_verif_output = gr.Textbox( + label='Output', + placeholder='Verification output', + interactive=False, + lines=1, + max_lines=10, + ) + + lora_model_verif_error = gr.Textbox( + label='Error', + placeholder='Verification error', + interactive=False, + lines=1, + max_lines=10, + ) + + verify_button.click( + verify_lora, + inputs=[ + lora_model, + ], + outputs=[lora_model_verif_output, lora_model_verif_error] + ) diff --git a/lora_gui.py b/lora_gui.py index 2a7421a..5c62d08 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -31,6 +31,7 @@ from library.dreambooth_folder_creation_gui import ( from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab from library.merge_lora_gui import gradio_merge_lora_tab +from library.verify_lora_gui import gradio_verify_lora_tab from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 @@ -675,6 +676,8 @@ def lora_tab( ) gradio_dataset_balancing_tab() gradio_merge_lora_tab() + gradio_verify_lora_tab() + button_run = gr.Button('Train model') diff --git a/setup.py b/setup.py index 8965557..55fe23b 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,3 @@ from setuptools import setup, find_packages -setup(name = "library", version="1.0.1", packages = find_packages()) \ No newline at end of file +setup(name = "library", version="1.0.2", packages = find_packages()) \ No newline at end of file