UniPC progress bar adjustment
This commit is contained in:
parent
22bcc7be42
commit
ae17e97898
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import math
|
import math
|
||||||
from tqdm.auto import trange
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
class NoiseScheduleVP:
|
class NoiseScheduleVP:
|
||||||
@ -757,40 +757,44 @@ class UniPC:
|
|||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
model_prev_list = [self.model_fn(x, vec_t)]
|
model_prev_list = [self.model_fn(x, vec_t)]
|
||||||
t_prev_list = [vec_t]
|
t_prev_list = [vec_t]
|
||||||
# Init the first `order` values by lower order multistep DPM-Solver.
|
with tqdm.tqdm(total=steps) as pbar:
|
||||||
for init_order in range(1, order):
|
# Init the first `order` values by lower order multistep DPM-Solver.
|
||||||
vec_t = timesteps[init_order].expand(x.shape[0])
|
for init_order in range(1, order):
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
vec_t = timesteps[init_order].expand(x.shape[0])
|
||||||
if model_x is None:
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
||||||
model_x = self.model_fn(x, vec_t)
|
|
||||||
if self.after_update is not None:
|
|
||||||
self.after_update(x, model_x)
|
|
||||||
model_prev_list.append(model_x)
|
|
||||||
t_prev_list.append(vec_t)
|
|
||||||
for step in trange(order, steps + 1):
|
|
||||||
vec_t = timesteps[step].expand(x.shape[0])
|
|
||||||
if lower_order_final:
|
|
||||||
step_order = min(order, steps + 1 - step)
|
|
||||||
else:
|
|
||||||
step_order = order
|
|
||||||
#print('this step order:', step_order)
|
|
||||||
if step == steps:
|
|
||||||
#print('do not run corrector at the last step')
|
|
||||||
use_corrector = False
|
|
||||||
else:
|
|
||||||
use_corrector = True
|
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
|
||||||
if self.after_update is not None:
|
|
||||||
self.after_update(x, model_x)
|
|
||||||
for i in range(order - 1):
|
|
||||||
t_prev_list[i] = t_prev_list[i + 1]
|
|
||||||
model_prev_list[i] = model_prev_list[i + 1]
|
|
||||||
t_prev_list[-1] = vec_t
|
|
||||||
# We do not need to evaluate the final model value.
|
|
||||||
if step < steps:
|
|
||||||
if model_x is None:
|
if model_x is None:
|
||||||
model_x = self.model_fn(x, vec_t)
|
model_x = self.model_fn(x, vec_t)
|
||||||
model_prev_list[-1] = model_x
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
|
model_prev_list.append(model_x)
|
||||||
|
t_prev_list.append(vec_t)
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
for step in range(order, steps + 1):
|
||||||
|
vec_t = timesteps[step].expand(x.shape[0])
|
||||||
|
if lower_order_final:
|
||||||
|
step_order = min(order, steps + 1 - step)
|
||||||
|
else:
|
||||||
|
step_order = order
|
||||||
|
#print('this step order:', step_order)
|
||||||
|
if step == steps:
|
||||||
|
#print('do not run corrector at the last step')
|
||||||
|
use_corrector = False
|
||||||
|
else:
|
||||||
|
use_corrector = True
|
||||||
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||||
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
|
for i in range(order - 1):
|
||||||
|
t_prev_list[i] = t_prev_list[i + 1]
|
||||||
|
model_prev_list[i] = model_prev_list[i + 1]
|
||||||
|
t_prev_list[-1] = vec_t
|
||||||
|
# We do not need to evaluate the final model value.
|
||||||
|
if step < steps:
|
||||||
|
if model_x is None:
|
||||||
|
model_x = self.model_fn(x, vec_t)
|
||||||
|
model_prev_list[-1] = model_x
|
||||||
|
pbar.update()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if denoise_to_zero:
|
if denoise_to_zero:
|
||||||
|
Loading…
Reference in New Issue
Block a user