make CLIP interrogator download original text files if the directory does not exist

remove random artist built-in extension (to re-added as a normal extension on demand)
remove artists.csv (but what does it mean????????????????????)
make interrogate buttons show Loading... when you click them
This commit is contained in:
AUTOMATIC 2023-01-21 09:14:27 +03:00
parent 40ff6db532
commit 6d805b669e
9 changed files with 46 additions and 3151 deletions

View File

@ -49,7 +49,6 @@ A browser interface based on Gradio library for Stable Diffusion.
- Running arbitrary python code from UI (must run with --allow-code to enable) - Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements - Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config - Possible to change defaults/mix/max/step values for UI elements via text config
- Random artist button
- Tiling support, a checkbox to create images that can be tiled like textures - Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview - Progress bar and live image generation preview
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image

File diff suppressed because it is too large Load Diff

View File

@ -1,50 +0,0 @@
import random
from modules import script_callbacks, shared
import gradio as gr
art_symbol = '\U0001f3a8' # 🎨
global_prompt = None
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
def roll_artist(prompt):
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
return prompt + ", " + artist.name if prompt != '' else artist.name
def add_roll_button(prompt):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
roll.click(
fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[
prompt,
],
outputs=[
prompt,
]
)
def after_component(component, **kwargs):
global global_prompt
elem_id = kwargs.get('elem_id', None)
if elem_id not in related_ids:
return
if elem_id == "txt2img_prompt":
global_prompt = component
elif elem_id == "txt2img_clear_prompt":
add_roll_button(global_prompt)
elif elem_id == "img2img_prompt":
global_prompt = component
elif elem_id == "img2img_clear_prompt":
add_roll_button(global_prompt)
script_callbacks.on_after_component(after_component)

View File

@ -14,7 +14,6 @@ titles = {
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result", "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory", "\u{1f4c2}": "Open images output directory",
"\u{1f4be}": "Save style", "\u{1f4be}": "Save style",

View File

@ -126,8 +126,6 @@ class Api:
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
@ -390,12 +388,6 @@ class Api:
return styleList return styleList
def get_artists_categories(self):
return shared.artist_db.cats
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def get_embeddings(self): def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db db = sd_hijack.model_hijack.embedding_db

View File

@ -1,25 +0,0 @@
import os.path
import csv
from collections import namedtuple
Artist = namedtuple("Artist", ['name', 'weight', 'category'])
class ArtistsDatabase:
def __init__(self, filename):
self.cats = set()
self.artists = []
if not os.path.exists(filename):
return
with open(filename, "r", newline='', encoding="utf8") as file:
reader = csv.DictReader(file)
for row in reader:
artist = Artist(row["artist"], float(row["score"]), row["category"])
self.artists.append(artist)
self.cats.add(artist.category)
def categories(self):
return sorted(self.cats)

View File

@ -5,12 +5,13 @@ from collections import namedtuple
import re import re
import torch import torch
import torch.hub
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared import modules.shared as shared
from modules import devices, paths, lowvram, modelloader from modules import devices, paths, lowvram, modelloader, errors
blip_image_eval_size = 384 blip_image_eval_size = 384
clip_model_name = 'ViT-L/14' clip_model_name = 'ViT-L/14'
@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.") re_topn = re.compile(r"\.top(\d+)\.")
def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp"
try:
os.makedirs(tmpdir)
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
os.rename(tmpdir, content_dir)
except Exception as e:
errors.display(e, "downloading default CLIP interrogate categories")
finally:
if os.path.exists(tmpdir):
os.remove(tmpdir)
class InterrogateModels: class InterrogateModels:
blip_model = None blip_model = None
clip_model = None clip_model = None
clip_preprocess = None clip_preprocess = None
categories = None
dtype = None dtype = None
running_on_cpu = None running_on_cpu = None
def __init__(self, content_dir): def __init__(self, content_dir):
self.categories = [] self.loaded_categories = None
self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu") self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir): def categories(self):
for filename in os.listdir(content_dir): if self.loaded_categories is not None:
return self.loaded_categories
self.loaded_categories = []
if not os.path.exists(self.content_dir):
download_default_clip_interrogate_categories(self.content_dir)
if os.path.exists(self.content_dir):
for filename in os.listdir(self.content_dir):
m = re_topn.search(filename) m = re_topn.search(filename)
topn = 1 if m is None else int(m.group(1)) topn = 1 if m is None else int(m.group(1))
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file: with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()] lines = [x.strip() for x in file.readlines()]
self.categories.append(Category(name=filename, topn=topn, items=lines)) self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
return self.loaded_categories
def load_blip_model(self): def load_blip_model(self):
import models.blip import models.blip
@ -139,7 +172,6 @@ class InterrogateModels:
shared.state.begin() shared.state.begin()
shared.state.job = 'interrogate' shared.state.job = 'interrogate'
try: try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu() lowvram.send_everything_to_cpu()
devices.torch_gc() devices.torch_gc()
@ -159,12 +191,7 @@ class InterrogateModels:
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
if shared.opts.interrogate_use_builtin_artists: for name, topn, items in self.categories():
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
res += ", " + artist[0]
for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn) matches = self.rank(image_features, items, top_count=topn)
for match, score in matches: for match, score in matches:
if shared.opts.interrogate_return_ranks: if shared.opts.interrogate_return_ranks:

View File

@ -9,7 +9,6 @@ from PIL import Image
import gradio as gr import gradio as gr
import tqdm import tqdm
import modules.artists
import modules.interrogate import modules.interrogate
import modules.memmon import modules.memmon
import modules.styles import modules.styles
@ -254,8 +253,6 @@ class State:
state = State() state = State()
state.server_start = time.time() state.server_start = time.time()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
styles_filename = cmd_opts.styles_file styles_filename = cmd_opts.styles_file
prompt_styles = modules.styles.StyleDatabase(styles_filename) prompt_styles = modules.styles.StyleDatabase(styles_filename)
@ -408,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
})) }))
options_templates.update(options_section(('compatibility', "Compatibility"), { options_templates.update(options_section(('compatibility', "Compatibility"), {
@ -419,7 +415,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
options_templates.update(options_section(('interrogate', "Interrogate Options"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."), "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),

View File

@ -228,17 +228,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
left, _ = os.path.splitext(filename) left, _ = os.path.splitext(filename)
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a')) print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
return [gr_show(True), None] return [gr.update(), None]
def interrogate(image): def interrogate(image):
prompt = shared.interrogator.interrogate(image.convert("RGB")) prompt = shared.interrogator.interrogate(image.convert("RGB"))
return gr_show(True) if prompt is None else prompt return gr.update() if prompt is None else prompt
def interrogate_deepbooru(image): def interrogate_deepbooru(image):
prompt = deepbooru.model.tag(image) prompt = deepbooru.model.tag(image)
return gr_show(True) if prompt is None else prompt return gr.update() if prompt is None else prompt
def create_seed_inputs(target_interface): def create_seed_inputs(target_interface):
@ -1039,7 +1039,6 @@ def create_ui():
init_img_inpaint, init_img_inpaint,
], ],
outputs=[img2img_prompt, dummy_component], outputs=[img2img_prompt, dummy_component],
show_progress=False,
) )
img2img_prompt.submit(**img2img_args) img2img_prompt.submit(**img2img_args)