use selected device instead of always cuda for UniPC sampler
This commit is contained in:
parent
a11ce2b96c
commit
f261a4a53c
@ -3,7 +3,8 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
||||||
from modules import shared
|
from modules import shared, devices
|
||||||
|
|
||||||
|
|
||||||
class UniPCSampler(object):
|
class UniPCSampler(object):
|
||||||
def __init__(self, model, **kwargs):
|
def __init__(self, model, **kwargs):
|
||||||
@ -16,8 +17,8 @@ class UniPCSampler(object):
|
|||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
if type(attr) == torch.Tensor:
|
if type(attr) == torch.Tensor:
|
||||||
if attr.device != torch.device("cuda"):
|
if attr.device != devices.device:
|
||||||
attr = attr.to(torch.device("cuda"))
|
attr = attr.to(devices.device)
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
def set_hooks(self, before_sample, after_sample, after_update):
|
def set_hooks(self, before_sample, after_sample, after_update):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user