From 4d5f1691dda971ec7b461dd880426300fd54ccee Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 28 Nov 2022 21:36:35 -0500 Subject: [PATCH] Use devices.autocast instead of torch.autocast --- modules/hypernetworks/hypernetwork.py | 2 +- modules/interrogate.py | 3 +-- modules/swinir_model.py | 6 +----- modules/textual_inversion/dataset.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 2 +- 5 files changed, 6 insertions(+), 11 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8466887f..eb5ae372 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -495,7 +495,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if shared.state.interrupted: break - with torch.autocast("cuda"): + with devices.autocast(): x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) if tag_drop_out != 0 or shuffle_tags: shared.sd_model.cond_stage_model.to(devices.device) diff --git a/modules/interrogate.py b/modules/interrogate.py index 9769aa34..40c6b082 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -148,8 +148,7 @@ class InterrogateModels: clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) - precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext - with torch.no_grad(), precision_scope("cuda"): + with torch.no_grad(), devices.autocast(): image_features = self.clip_model.encode_image(clip_image).type(self.dtype) image_features /= image_features.norm(dim=-1, keepdim=True) diff --git a/modules/swinir_model.py b/modules/swinir_model.py index facd262d..483eabd4 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -13,10 +13,6 @@ from modules.swinir_model_arch import SwinIR as net from modules.swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData -precision_scope = ( - torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext -) - class UpscalerSwinIR(Upscaler): def __init__(self, dirname): @@ -112,7 +108,7 @@ def upscale( img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() img = img.unsqueeze(0).to(devices.device_swinir) - with torch.no_grad(), precision_scope("cuda"): + with torch.no_grad(), devices.autocast(): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index e5725f33..2dc64c3c 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -82,7 +82,7 @@ class PersonalizedBase(Dataset): torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) latent_sample = None - with torch.autocast("cuda"): + with devices.autocast(): latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): @@ -101,7 +101,7 @@ class PersonalizedBase(Dataset): entry.cond_text = self.create_text(filename_text) if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): - with torch.autocast("cuda"): + with devices.autocast(): entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) self.dataset.append(entry) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4eb75cb5..daf8d1b8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -316,7 +316,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ if shared.state.interrupted: break - with torch.autocast("cuda"): + with devices.autocast(): # c = stack_conds(batch.cond).to(devices.device) # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory) # print(mask)