add pbar to unipc
This commit is contained in:
parent
dfeee786f9
commit
03a80f198e
@ -71,7 +71,7 @@ class UniPCSampler(object):
|
|||||||
# sampling
|
# sampling
|
||||||
C, H, W = shape
|
C, H, W = shape
|
||||||
size = (batch_size, C, H, W)
|
size = (batch_size, C, H, W)
|
||||||
print(f'Data shape for UniPC sampling is {size}')
|
# print(f'Data shape for UniPC sampling is {size}')
|
||||||
|
|
||||||
device = self.model.betas.device
|
device = self.model.betas.device
|
||||||
if x_T is None:
|
if x_T is None:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import math
|
import math
|
||||||
|
from tqdm.auto import trange
|
||||||
|
|
||||||
|
|
||||||
class NoiseScheduleVP:
|
class NoiseScheduleVP:
|
||||||
@ -750,7 +751,7 @@ class UniPC:
|
|||||||
if method == 'multistep':
|
if method == 'multistep':
|
||||||
assert steps >= order, "UniPC order must be < sampling steps"
|
assert steps >= order, "UniPC order must be < sampling steps"
|
||||||
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||||
print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
|
#print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
|
||||||
assert timesteps.shape[0] - 1 == steps
|
assert timesteps.shape[0] - 1 == steps
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
@ -766,7 +767,7 @@ class UniPC:
|
|||||||
self.after_update(x, model_x)
|
self.after_update(x, model_x)
|
||||||
model_prev_list.append(model_x)
|
model_prev_list.append(model_x)
|
||||||
t_prev_list.append(vec_t)
|
t_prev_list.append(vec_t)
|
||||||
for step in range(order, steps + 1):
|
for step in trange(order, steps + 1):
|
||||||
vec_t = timesteps[step].expand(x.shape[0])
|
vec_t = timesteps[step].expand(x.shape[0])
|
||||||
if lower_order_final:
|
if lower_order_final:
|
||||||
step_order = min(order, steps + 1 - step)
|
step_order = min(order, steps + 1 - step)
|
||||||
|
Loading…
Reference in New Issue
Block a user