statistics for pbar
This commit is contained in:
parent
40b56c9289
commit
348f89c8d4
@ -335,6 +335,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
size = len(ds.indexes)
|
||||
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||
losses = torch.zeros((size,))
|
||||
previous_mean_losses = [0]
|
||||
previous_mean_loss = 0
|
||||
print("Mean loss of {} elements".format(size))
|
||||
|
||||
@ -356,7 +357,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
for i, entries in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
if len(loss_dict) > 0:
|
||||
previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
|
||||
previous_mean_losses = [i[-1] for i in loss_dict.values()]
|
||||
previous_mean_loss = mean(previous_mean_losses)
|
||||
|
||||
scheduler.apply(optimizer, hypernetwork.step)
|
||||
if scheduler.finished:
|
||||
@ -391,7 +393,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
|
||||
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
||||
raise RuntimeError("Loss diverged.")
|
||||
pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
|
||||
|
||||
if len(previous_mean_losses) > 1:
|
||||
std = stdev(previous_mean_losses)
|
||||
else:
|
||||
std = 0
|
||||
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
|
||||
pbar.set_description(dataset_loss_info)
|
||||
|
||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
|
Loading…
Reference in New Issue
Block a user