fix for unet hijack breaking the train tab

This commit is contained in:
AUTOMATIC 2023-01-25 20:11:01 +03:00
parent 789d47f832
commit 15e89ef0f6

View File

@ -36,8 +36,11 @@ th = TorchHijackForUnet()
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
if isinstance(cond, dict):
for y in cond.keys(): for y in cond.keys():
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
with devices.autocast(): with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()