Merge pull request #10268 from Sakura-Luna/pbar

UniPC progress bar adjustment
This commit is contained in:
AUTOMATIC1111 2023-05-11 08:16:36 +03:00 committed by GitHub
commit fe5d988947
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,6 @@
import torch import torch
import math import math
from tqdm.auto import trange import tqdm
class NoiseScheduleVP: class NoiseScheduleVP:
@ -759,6 +759,7 @@ class UniPC:
vec_t = timesteps[0].expand((x.shape[0])) vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)] model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t] t_prev_list = [vec_t]
with tqdm.tqdm(total=steps) as pbar:
# Init the first `order` values by lower order multistep DPM-Solver. # Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order): for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0]) vec_t = timesteps[init_order].expand(x.shape[0])
@ -769,7 +770,9 @@ class UniPC:
self.after_update(x, model_x) self.after_update(x, model_x)
model_prev_list.append(model_x) model_prev_list.append(model_x)
t_prev_list.append(vec_t) t_prev_list.append(vec_t)
for step in trange(order, steps + 1): pbar.update()
for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0]) vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final: if lower_order_final:
step_order = min(order, steps + 1 - step) step_order = min(order, steps + 1 - step)
@ -793,6 +796,7 @@ class UniPC:
if model_x is None: if model_x is None:
model_x = self.model_fn(x, vec_t) model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x model_prev_list[-1] = model_x
pbar.update()
else: else:
raise NotImplementedError() raise NotImplementedError()
if denoise_to_zero: if denoise_to_zero: