Merge branch 'AUTOMATIC1111:master' into master
This commit is contained in:
commit
0186db178e
@ -81,6 +81,9 @@ titles = {
|
||||
|
||||
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
|
||||
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
|
||||
|
||||
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
|
||||
"Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.",
|
||||
}
|
||||
|
||||
|
||||
|
@ -2,33 +2,44 @@ import os.path
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import multiprocessing
|
||||
import time
|
||||
import re
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
||||
def get_deepbooru_tags(pil_image):
|
||||
"""
|
||||
This method is for running only one image at a time for simple use. Used to the img2img interrogate.
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha)
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process_queue.put(pil_image)
|
||||
while shared.deepbooru_process_return["value"] == -1:
|
||||
time.sleep(0.2)
|
||||
tags = shared.deepbooru_process_return["value"]
|
||||
release_process()
|
||||
return tags
|
||||
|
||||
try:
|
||||
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
|
||||
return get_tags_from_process(pil_image)
|
||||
finally:
|
||||
release_process()
|
||||
|
||||
|
||||
def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort):
|
||||
def create_deepbooru_opts():
|
||||
from modules import shared
|
||||
|
||||
return {
|
||||
"use_spaces": shared.opts.deepbooru_use_spaces,
|
||||
"use_escape": shared.opts.deepbooru_escape,
|
||||
"alpha_sort": shared.opts.deepbooru_sort_alpha,
|
||||
}
|
||||
|
||||
|
||||
def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
|
||||
model, tags = get_deepbooru_tags_model()
|
||||
while True: # while process is running, keep monitoring queue for new image
|
||||
pil_image = queue.get()
|
||||
if pil_image == "QUIT":
|
||||
break
|
||||
else:
|
||||
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort)
|
||||
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
|
||||
|
||||
|
||||
def create_deepbooru_process(threshold, alpha_sort):
|
||||
def create_deepbooru_process(threshold, deepbooru_opts):
|
||||
"""
|
||||
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
|
||||
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
|
||||
@ -41,10 +52,23 @@ def create_deepbooru_process(threshold, alpha_sort):
|
||||
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
|
||||
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort))
|
||||
shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
|
||||
shared.deepbooru_process.start()
|
||||
|
||||
|
||||
def get_tags_from_process(image):
|
||||
from modules import shared
|
||||
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process_queue.put(image)
|
||||
while shared.deepbooru_process_return["value"] == -1:
|
||||
time.sleep(0.2)
|
||||
caption = shared.deepbooru_process_return["value"]
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
|
||||
return caption
|
||||
|
||||
|
||||
def release_process():
|
||||
"""
|
||||
Stops the deepbooru process to return used memory
|
||||
@ -81,10 +105,15 @@ def get_deepbooru_tags_model():
|
||||
return model, tags
|
||||
|
||||
|
||||
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort):
|
||||
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
|
||||
import deepdanbooru as dd
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
alpha_sort = deepbooru_opts['alpha_sort']
|
||||
use_spaces = deepbooru_opts['use_spaces']
|
||||
use_escape = deepbooru_opts['use_escape']
|
||||
|
||||
width = model.input_shape[2]
|
||||
height = model.input_shape[1]
|
||||
image = np.array(pil_image)
|
||||
@ -129,4 +158,12 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort)
|
||||
|
||||
print('\n'.join(sorted(result_tags_print, reverse=True)))
|
||||
|
||||
return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
|
||||
tags_text = ', '.join(result_tags_out)
|
||||
|
||||
if use_spaces:
|
||||
tags_text = tags_text.replace('_', ' ')
|
||||
|
||||
if use_escape:
|
||||
tags_text = re.sub(re_special, r'\\\1', tags_text)
|
||||
|
||||
return tags_text.replace(':', ' ')
|
||||
|
@ -14,7 +14,7 @@ import torch
|
||||
from torch import einsum
|
||||
from einops import rearrange, repeat
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnSchedule
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
@ -223,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
if ititial_step > steps:
|
||||
return hypernetwork, filename
|
||||
|
||||
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
print(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
|
||||
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||
for i, (x, text, cond) in pbar:
|
||||
for i, entry in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
|
||||
if hypernetwork.step > end_step:
|
||||
try:
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
except Exception:
|
||||
break
|
||||
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
for pg in optimizer.param_groups:
|
||||
pg['lr'] = learn_rate
|
||||
scheduler.apply(optimizer, hypernetwork.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
cond = cond.to(devices.device)
|
||||
x = x.to(devices.device)
|
||||
cond = entry.cond.to(devices.device)
|
||||
x = entry.latent.to(devices.device)
|
||||
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
|
||||
del x
|
||||
del cond
|
||||
@ -267,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||
|
||||
preview_text = text if preview_image_prompt == "" else preview_image_prompt
|
||||
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
|
||||
|
||||
optimizer.zero_grad()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
@ -282,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
)
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0]
|
||||
image = processed.images[0] if len(processed.images)>0 else None
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
shared.state.current_image = image
|
||||
image.save(last_saved_image)
|
||||
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
image.save(last_saved_image)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
||||
@ -299,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
<p>
|
||||
Loss: {losses.mean():.7f}<br/>
|
||||
Step: {hypernetwork.step}<br/>
|
||||
Last prompt: {html.escape(text)}<br/>
|
||||
Last prompt: {html.escape(entry.cond_text)}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
|
@ -231,6 +231,9 @@ options_templates.update(options_section(('system', "System"), {
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
@ -257,6 +260,8 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
|
||||
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
|
||||
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
|
@ -11,11 +11,21 @@ import tqdm
|
||||
from modules import devices, shared
|
||||
import re
|
||||
|
||||
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
|
||||
|
||||
class DatasetEntry:
|
||||
def __init__(self, filename=None, latent=None, filename_text=None):
|
||||
self.filename = filename
|
||||
self.latent = latent
|
||||
self.filename_text = filename_text
|
||||
self.cond = None
|
||||
self.cond_text = None
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
text_filename = os.path.splitext(path)[0] + ".txt"
|
||||
filename = os.path.basename(path)
|
||||
filename_tokens = os.path.splitext(filename)[0]
|
||||
filename_tokens = re_tag.findall(filename_tokens)
|
||||
|
||||
if os.path.exists(text_filename):
|
||||
with open(text_filename, "r", encoding="utf8") as file:
|
||||
filename_text = file.read()
|
||||
else:
|
||||
filename_text = os.path.splitext(filename)[0]
|
||||
filename_text = re.sub(re_numbers_at_start, '', filename_text)
|
||||
if re_word:
|
||||
tokens = re_word.findall(filename_text)
|
||||
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
||||
|
||||
npimage = np.array(image).astype(np.uint8)
|
||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||
init_latent = init_latent.to(devices.cpu)
|
||||
|
||||
if include_cond:
|
||||
text = self.create_text(filename_tokens)
|
||||
cond = cond_model([text]).to(devices.cpu)
|
||||
else:
|
||||
cond = None
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
|
||||
|
||||
self.dataset.append((init_latent, filename_tokens, cond))
|
||||
if include_cond:
|
||||
entry.cond_text = self.create_text(filename_text)
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
|
||||
|
||||
self.dataset.append(entry)
|
||||
|
||||
self.length = len(self.dataset) * repeats
|
||||
|
||||
@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
|
||||
def shuffle(self):
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||
|
||||
def create_text(self, filename_tokens):
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||
text = text.replace("[filewords]", filename_text)
|
||||
return text
|
||||
|
||||
def __len__(self):
|
||||
@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
|
||||
self.shuffle()
|
||||
|
||||
index = self.indexes[i % len(self.indexes)]
|
||||
x, filename_tokens, cond = self.dataset[index]
|
||||
entry = self.dataset[index]
|
||||
|
||||
text = self.create_text(filename_tokens)
|
||||
return x, text, cond
|
||||
if entry.cond is None:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
|
||||
return entry
|
||||
|
@ -1,6 +1,12 @@
|
||||
import tqdm
|
||||
|
||||
class LearnSchedule:
|
||||
|
||||
class LearnScheduleIterator:
|
||||
def __init__(self, learn_rate, max_steps, cur_step=0):
|
||||
"""
|
||||
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
|
||||
"""
|
||||
|
||||
pairs = learn_rate.split(',')
|
||||
self.rates = []
|
||||
self.it = 0
|
||||
@ -32,3 +38,32 @@ class LearnSchedule:
|
||||
return self.rates[self.it - 1]
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
|
||||
class LearnRateScheduler:
|
||||
def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
|
||||
self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
|
||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
||||
self.verbose = verbose
|
||||
|
||||
if self.verbose:
|
||||
print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
|
||||
|
||||
self.finished = False
|
||||
|
||||
def apply(self, optimizer, step_number):
|
||||
if step_number <= self.end_step:
|
||||
return
|
||||
|
||||
try:
|
||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
||||
except Exception:
|
||||
self.finished = True
|
||||
return
|
||||
|
||||
if self.verbose:
|
||||
tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
|
||||
|
||||
for pg in optimizer.param_groups:
|
||||
pg['lr'] = self.learn_rate
|
||||
|
||||
|
@ -10,7 +10,28 @@ from modules.shared import opts, cmd_opts
|
||||
if cmd_opts.deepdanbooru:
|
||||
import modules.deepbooru as deepbooru
|
||||
|
||||
|
||||
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
|
||||
try:
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
|
||||
if process_caption_deepbooru:
|
||||
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts())
|
||||
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
|
||||
|
||||
finally:
|
||||
|
||||
if process_caption:
|
||||
shared.interrogator.send_blip_to_ram()
|
||||
|
||||
if process_caption_deepbooru:
|
||||
deepbooru.release_process()
|
||||
|
||||
|
||||
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
|
||||
width = process_width
|
||||
height = process_height
|
||||
src = os.path.abspath(process_src)
|
||||
@ -25,30 +46,28 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
||||
shared.state.textinfo = "Preprocessing..."
|
||||
shared.state.job_count = len(files)
|
||||
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
|
||||
if process_caption_deepbooru:
|
||||
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha)
|
||||
|
||||
def save_pic_with_caption(image, index):
|
||||
if process_caption:
|
||||
caption = "-" + shared.interrogator.generate_caption(image)
|
||||
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
|
||||
elif process_caption_deepbooru:
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process_queue.put(image)
|
||||
while shared.deepbooru_process_return["value"] == -1:
|
||||
time.sleep(0.2)
|
||||
caption = "-" + shared.deepbooru_process_return["value"]
|
||||
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
else:
|
||||
caption = filename
|
||||
caption = os.path.splitext(caption)[0]
|
||||
caption = os.path.basename(caption)
|
||||
caption = ""
|
||||
|
||||
if process_caption:
|
||||
caption += shared.interrogator.generate_caption(image)
|
||||
|
||||
if process_caption_deepbooru:
|
||||
if len(caption) > 0:
|
||||
caption += ", "
|
||||
caption += deepbooru.get_tags_from_process(image)
|
||||
|
||||
filename_part = filename
|
||||
filename_part = os.path.splitext(filename_part)[0]
|
||||
filename_part = os.path.basename(filename_part)
|
||||
|
||||
basename = f"{index:05}-{subindex[0]}-{filename_part}"
|
||||
image.save(os.path.join(dst, f"{basename}.png"))
|
||||
|
||||
if len(caption) > 0:
|
||||
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
|
||||
file.write(caption)
|
||||
|
||||
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
|
||||
subindex[0] += 1
|
||||
|
||||
def save_pic(image, index):
|
||||
@ -93,34 +112,3 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
||||
save_pic(img, index)
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
if process_caption:
|
||||
shared.interrogator.send_blip_to_ram()
|
||||
|
||||
if process_caption_deepbooru:
|
||||
deepbooru.release_process()
|
||||
|
||||
|
||||
def sanitize_caption(base_path, original_caption, suffix):
|
||||
operating_system = platform.system().lower()
|
||||
if (operating_system == "windows"):
|
||||
invalid_path_characters = "\\/:*?\"<>|"
|
||||
max_path_length = 259
|
||||
else:
|
||||
invalid_path_characters = "/" #linux/macos
|
||||
max_path_length = 1023
|
||||
caption = original_caption
|
||||
for invalid_character in invalid_path_characters:
|
||||
caption = caption.replace(invalid_character, "")
|
||||
fixed_path_length = len(base_path) + len(suffix)
|
||||
if fixed_path_length + len(caption) <= max_path_length:
|
||||
return caption
|
||||
caption_tokens = caption.split()
|
||||
new_caption = ""
|
||||
for token in caption_tokens:
|
||||
last_caption = new_caption
|
||||
new_caption = new_caption + token + " "
|
||||
if (len(new_caption) + fixed_path_length - 1 > max_path_length):
|
||||
break
|
||||
print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
|
||||
return last_caption.strip()
|
||||
|
@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnSchedule
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
|
||||
insert_image_data_embed, extract_image_data_embed,
|
||||
@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||
return fn
|
||||
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||
|
||||
hijack = sd_hijack.model_hijack
|
||||
|
||||
@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
if ititial_step > steps:
|
||||
return embedding, filename
|
||||
|
||||
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
print(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, (x, text, _) in pbar:
|
||||
for i, entry in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
|
||||
if embedding.step > end_step:
|
||||
try:
|
||||
(learn_rate, end_step) = next(schedules)
|
||||
except:
|
||||
break
|
||||
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
||||
for pg in optimizer.param_groups:
|
||||
pg['lr'] = learn_rate
|
||||
scheduler.apply(optimizer, embedding.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = cond_model([text])
|
||||
c = cond_model([entry.cond_text])
|
||||
|
||||
x = x.to(devices.device)
|
||||
x = entry.latent.to(devices.device)
|
||||
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
||||
del x
|
||||
|
||||
@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
|
||||
preview_text = text if preview_image_prompt == "" else preview_image_prompt
|
||||
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
<p>
|
||||
Loss: {losses.mean():.7f}<br/>
|
||||
Step: {embedding.step}<br/>
|
||||
Last prompt: {html.escape(text)}<br/>
|
||||
Last prompt: {html.escape(entry.cond_text)}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
|
@ -1082,11 +1082,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.Row():
|
||||
process_flip = gr.Checkbox(label='Create flipped copies')
|
||||
process_split = gr.Checkbox(label='Split oversized images into two')
|
||||
process_caption = gr.Checkbox(label='Use BLIP caption as filename')
|
||||
if cmd_opts.deepdanbooru:
|
||||
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename')
|
||||
else:
|
||||
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename', visible=False)
|
||||
process_caption = gr.Checkbox(label='Use BLIP for caption')
|
||||
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
@ -1106,7 +1103,6 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||
num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
|
||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||
@ -1184,7 +1180,6 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
training_width,
|
||||
training_height,
|
||||
steps,
|
||||
num_repeats,
|
||||
create_image_every,
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
|
Loading…
Reference in New Issue
Block a user