From eb95809501068a38f2b6bdb01b6ae5b86ff7ae87 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 9 May 2023 11:25:46 +0300 Subject: [PATCH] rework loras api --- extensions-builtin/Lora/lora.py | 6 ---- extensions-builtin/Lora/scripts/api.py | 31 ------------------- .../Lora/scripts/lora_script.py | 21 ++++++++++++- 3 files changed, 20 insertions(+), 38 deletions(-) delete mode 100644 extensions-builtin/Lora/scripts/api.py diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 05162e41..ba1293df 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -3,7 +3,6 @@ import os import re import torch from typing import Union -import scripts.api as api from modules import shared, devices, sd_models, errors, scripts @@ -449,8 +448,3 @@ available_lora_aliases = {} loaded_loras = [] list_available_loras() -try: - import modules.script_callbacks as script_callbacks - script_callbacks.on_app_started(api.api) -except: - pass \ No newline at end of file diff --git a/extensions-builtin/Lora/scripts/api.py b/extensions-builtin/Lora/scripts/api.py deleted file mode 100644 index f1f2e2fc..00000000 --- a/extensions-builtin/Lora/scripts/api.py +++ /dev/null @@ -1,31 +0,0 @@ -from fastapi import FastAPI -import gradio as gr -import json -import os -import lora - -def get_lora_prompts(path): - directory, filename = os.path.split(path) - name_without_ext = os.path.splitext(filename)[0] - new_filename = name_without_ext + '.civitai.info' - try: - new_path = os.path.join(directory, new_filename) - if os.path.exists(new_path): - with open(new_path, 'r') as f: - data = json.load(f) - trained_words = data.get('trainedWords', []) - if len(trained_words) > 0: - result = ','.join(trained_words) - return result - else: - return '' - else: - return '' - except Exception as e: - return '' - -def api(_: gr.Blocks, app: FastAPI): - @app.get("/sdapi/v1/loras") - async def get_loras(): - return [{"name": name, "path": lora.available_loras[name].filename, "prompt": get_lora_prompts(lora.available_loras[name].filename)} for name in lora.available_loras] - diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index a67b8a69..7db971fd 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,12 +1,12 @@ import torch import gradio as gr +from fastapi import FastAPI import lora import extra_networks_lora import ui_extra_networks_lora from modules import script_callbacks, ui_extra_networks, extra_networks, shared - def unload(): torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora @@ -60,3 +60,22 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), { "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), })) + + +def create_lora_json(obj: lora.LoraOnDisk): + return { + "name": obj.name, + "alias": obj.alias, + "path": obj.filename, + "metadata": obj.metadata, + } + + +def api_loras(_: gr.Blocks, app: FastAPI): + @app.get("/sdapi/v1/loras") + async def get_loras(): + return [create_lora_json(obj) for obj in lora.available_loras.values()] + + +script_callbacks.on_app_started(api_loras) +