Merge pull request #3264 from Melanpan/tensorboard
Add support for Tensorboard (training)
This commit is contained in:
commit
1849f6eb80
@ -24,7 +24,6 @@ from statistics import stdev, mean
|
|||||||
|
|
||||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
activation_dict = {
|
activation_dict = {
|
||||||
@ -498,6 +497,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||||
|
|
||||||
|
if shared.opts.training_enable_tensorboard:
|
||||||
|
tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
|
||||||
|
|
||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
|
|
||||||
@ -632,6 +634,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if shared.opts.training_enable_tensorboard:
|
||||||
|
epoch_num = hypernetwork.step // len(ds)
|
||||||
|
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
||||||
|
|
||||||
|
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
|
||||||
|
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||||
"loss": f"{loss_step:.7f}",
|
"loss": f"{loss_step:.7f}",
|
||||||
"learn_rate": scheduler.learn_rate
|
"learn_rate": scheduler.learn_rate
|
||||||
@ -673,6 +683,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||||||
|
|
||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
image = processed.images[0] if len(processed.images) > 0 else None
|
image = processed.images[0] if len(processed.images) > 0 else None
|
||||||
|
|
||||||
|
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||||
|
textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
@ -373,6 +373,9 @@ options_templates.update(options_section(('training', "Training"), {
|
|||||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||||
|
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
||||||
|
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
||||||
|
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
|
@ -12,6 +12,7 @@ import csv
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
@ -294,6 +295,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
|||||||
**values,
|
**values,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def tensorboard_setup(log_directory):
    """Create the tensorboard output folder and return a writer bound to it.

    Event files go to the ``tensorboard`` subdirectory of *log_directory*
    (created if missing). The writer's flush interval is read from the
    ``training_tensorboard_flush_every`` option.
    """
    tensorboard_dir = os.path.join(log_directory, "tensorboard")
    os.makedirs(tensorboard_dir, exist_ok=True)

    flush_interval = shared.opts.training_tensorboard_flush_every
    return SummaryWriter(log_dir=tensorboard_dir, flush_secs=flush_interval)
|
||||||
|
|
||||||
|
def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
    """Record loss and learn rate for the current training step.

    Each value is written twice: once under a run-wide tag indexed by
    *global_step*, and once under a per-epoch tag (suffixed with
    *epoch_num*) indexed by *step*.
    """
    # (tag, value, x-axis position) in the order they are written.
    records = (
        ("Loss/train", loss, global_step),
        (f"Loss/train/epoch-{epoch_num}", loss, step),
        ("Learn rate/train", learn_rate, global_step),
        (f"Learn rate/train/epoch-{epoch_num}", learn_rate, step),
    )
    for tag, value, position in records:
        tensorboard_add_scaler(tensorboard_writer, tag, value, position)
|
||||||
|
|
||||||
|
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
    """Write a single scalar sample to tensorboard.

    Forwards to ``tensorboard_writer.add_scalar`` using keyword arguments,
    with *value* as ``scalar_value`` and *step* as ``global_step``.
    """
    scalar_kwargs = {"tag": tag, "scalar_value": value, "global_step": step}
    tensorboard_writer.add_scalar(**scalar_kwargs)
|
||||||
|
|
||||||
|
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
    """Log a PIL image to tensorboard under *tag* at *step*.

    Does nothing when *pil_image* is None: the training loops may end up
    without a preview image, and a missing image must not abort training.

    The image data is reshaped to (height, width, bands) and then permuted
    to bands-first before being handed to ``tensorboard_writer.add_image``.
    """
    if pil_image is None:
        return

    # Convert a pil image to a torch tensor
    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))

    # PIL reports size as (width, height); the array is laid out rows-first.
    width, height = pil_image.size
    img_tensor = img_tensor.view(height, width, len(pil_image.getbands()))

    # Move the band axis to the front (H, W, C) -> (C, H, W).
    img_tensor = img_tensor.permute((2, 0, 1))

    tensorboard_writer.add_image(tag, img_tensor, global_step=step)
|
||||||
|
|
||||||
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||||
assert model_name, f"{name} not selected"
|
assert model_name, f"{name} not selected"
|
||||||
@ -372,6 +397,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||||
|
|
||||||
|
if shared.opts.training_enable_tensorboard:
|
||||||
|
tensorboard_writer = tensorboard_setup(log_directory)
|
||||||
|
|
||||||
pin_memory = shared.opts.pin_memory
|
pin_memory = shared.opts.pin_memory
|
||||||
|
|
||||||
@ -535,6 +563,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
|
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||||
|
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
|
||||||
|
|
||||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||||
|
|
||||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||||
|
Loading…
Reference in New Issue
Block a user