add an option to unload models during hypernetwork training to save VRAM
This commit is contained in:
parent
6d09b8d1df
commit
d4ea5f4d86
@ -175,6 +175,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
|||||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||||
|
|
||||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
||||||
|
unload = shared.opts.unload_models_when_training
|
||||||
|
|
||||||
if save_hypernetwork_every > 0:
|
if save_hypernetwork_every > 0:
|
||||||
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
|
||||||
@ -188,11 +189,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
|||||||
else:
|
else:
|
||||||
images_dir = None
|
images_dir = None
|
||||||
|
|
||||||
cond_model = shared.sd_model.cond_stage_model
|
|
||||||
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
|
||||||
|
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
hypernetwork = shared.loaded_hypernetwork
|
hypernetwork = shared.loaded_hypernetwork
|
||||||
weights = hypernetwork.weights()
|
weights = hypernetwork.weights()
|
||||||
@ -211,7 +214,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
|||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||||
for i, (x, text) in pbar:
|
for i, (x, text, cond) in pbar:
|
||||||
hypernetwork.step = i + ititial_step
|
hypernetwork.step = i + ititial_step
|
||||||
|
|
||||||
if hypernetwork.step > steps:
|
if hypernetwork.step > steps:
|
||||||
@ -221,11 +224,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
|||||||
break
|
break
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
c = cond_model([text])
|
cond = cond.to(devices.device)
|
||||||
|
|
||||||
x = x.to(devices.device)
|
x = x.to(devices.device)
|
||||||
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
|
||||||
del x
|
del x
|
||||||
|
del cond
|
||||||
|
|
||||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||||
|
|
||||||
@ -244,6 +247,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
|||||||
|
|
||||||
preview_text = text if preview_image_prompt == "" else preview_image_prompt
|
preview_text = text if preview_image_prompt == "" else preview_image_prompt
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
prompt=preview_text,
|
prompt=preview_text,
|
||||||
@ -255,6 +262,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
|||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
image = processed.images[0]
|
image = processed.images[0]
|
||||||
|
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
shared.state.current_image = image
|
shared.state.current_image = image
|
||||||
image.save(last_saved_image)
|
image.save(last_saved_image)
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import gradio as gr
|
|||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
import modules.textual_inversion.preprocess
|
import modules.textual_inversion.preprocess
|
||||||
from modules import sd_hijack, shared
|
from modules import sd_hijack, shared, devices
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
@ -41,5 +41,7 @@ Hypernetwork saved to {html.escape(filename)}
|
|||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
shared.loaded_hypernetwork = initial_hypernetwork
|
shared.loaded_hypernetwork = initial_hypernetwork
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
|
@ -228,6 +228,10 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
|
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"),
|
||||||
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
|
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
|
||||||
|
@ -8,14 +8,14 @@ from torchvision import transforms
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
import tqdm
|
import tqdm
|
||||||
from modules import devices
|
from modules import devices, shared
|
||||||
import re
|
import re
|
||||||
|
|
||||||
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
||||||
|
|
||||||
|
|
||||||
class PersonalizedBase(Dataset):
|
class PersonalizedBase(Dataset):
|
||||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
|
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
|
||||||
|
|
||||||
self.placeholder_token = placeholder_token
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
@ -32,6 +32,8 @@ class PersonalizedBase(Dataset):
|
|||||||
|
|
||||||
assert data_root, 'dataset directory not specified'
|
assert data_root, 'dataset directory not specified'
|
||||||
|
|
||||||
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
|
|
||||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
print("Preparing dataset...")
|
print("Preparing dataset...")
|
||||||
for path in tqdm.tqdm(self.image_paths):
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
@ -53,7 +55,13 @@ class PersonalizedBase(Dataset):
|
|||||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||||
init_latent = init_latent.to(devices.cpu)
|
init_latent = init_latent.to(devices.cpu)
|
||||||
|
|
||||||
self.dataset.append((init_latent, filename_tokens))
|
if include_cond:
|
||||||
|
text = self.create_text(filename_tokens)
|
||||||
|
cond = cond_model([text]).to(devices.cpu)
|
||||||
|
else:
|
||||||
|
cond = None
|
||||||
|
|
||||||
|
self.dataset.append((init_latent, filename_tokens, cond))
|
||||||
|
|
||||||
self.length = len(self.dataset) * repeats
|
self.length = len(self.dataset) * repeats
|
||||||
|
|
||||||
@ -64,6 +72,12 @@ class PersonalizedBase(Dataset):
|
|||||||
def shuffle(self):
|
def shuffle(self):
|
||||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||||
|
|
||||||
|
def create_text(self, filename_tokens):
|
||||||
|
text = random.choice(self.lines)
|
||||||
|
text = text.replace("[name]", self.placeholder_token)
|
||||||
|
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||||
|
return text
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
@ -72,10 +86,7 @@ class PersonalizedBase(Dataset):
|
|||||||
self.shuffle()
|
self.shuffle()
|
||||||
|
|
||||||
index = self.indexes[i % len(self.indexes)]
|
index = self.indexes[i % len(self.indexes)]
|
||||||
x, filename_tokens = self.dataset[index]
|
x, filename_tokens, cond = self.dataset[index]
|
||||||
|
|
||||||
text = random.choice(self.lines)
|
text = self.create_text(filename_tokens)
|
||||||
text = text.replace("[name]", self.placeholder_token)
|
return x, text, cond
|
||||||
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
|
||||||
|
|
||||||
return x, text
|
|
||||||
|
@ -201,7 +201,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
|||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||||
for i, (x, text) in pbar:
|
for i, (x, text, _) in pbar:
|
||||||
embedding.step = i + ititial_step
|
embedding.step = i + ititial_step
|
||||||
|
|
||||||
if embedding.step > steps:
|
if embedding.step > steps:
|
||||||
|
Loading…
Reference in New Issue
Block a user