UniPC progress bar adjustment

This commit is contained in:
Sakura-Luna 2023-05-11 12:26:04 +08:00
parent 22bcc7be42
commit ae17e97898

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import math import math
from tqdm.auto import trange import tqdm
class NoiseScheduleVP: class NoiseScheduleVP:
@ -757,6 +757,7 @@ class UniPC:
vec_t = timesteps[0].expand((x.shape[0])) vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)] model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t] t_prev_list = [vec_t]
with tqdm.tqdm(total=steps) as pbar:
# Init the first `order` values by lower order multistep DPM-Solver. # Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order): for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0]) vec_t = timesteps[init_order].expand(x.shape[0])
@ -767,7 +768,9 @@ class UniPC:
self.after_update(x, model_x) self.after_update(x, model_x)
model_prev_list.append(model_x) model_prev_list.append(model_x)
t_prev_list.append(vec_t) t_prev_list.append(vec_t)
for step in trange(order, steps + 1): pbar.update()
for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0]) vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final: if lower_order_final:
step_order = min(order, steps + 1 - step) step_order = min(order, steps + 1 - step)
@ -791,6 +794,7 @@ class UniPC:
if model_x is None: if model_x is None:
model_x = self.model_fn(x, vec_t) model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x model_prev_list[-1] = model_x
pbar.update()
else: else:
raise NotImplementedError() raise NotImplementedError()
if denoise_to_zero: if denoise_to_zero: