change formatting to match the main program in devices.py
This commit is contained in:
parent
c62d17aee3
commit
0ab0a50f9a
@ -3,23 +3,27 @@ import contextlib
|
|||||||
import torch
|
import torch
|
||||||
from modules import errors
|
from modules import errors
|
||||||
|
|
||||||
|
|
||||||
# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
|
# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
|
||||||
# check `getattr` and try it for compatibility
|
# check `getattr` and try it for compatibility
|
||||||
def has_mps() -> bool:
|
def has_mps() -> bool:
|
||||||
if not getattr(torch, 'has_mps', False): return False
|
if not getattr(torch, 'has_mps', False):
|
||||||
|
return False
|
||||||
try:
|
try:
|
||||||
torch.zeros(1).to(torch.device("mps"))
|
torch.zeros(1).to(torch.device("mps"))
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
cpu = torch.device("cpu")
|
|
||||||
|
|
||||||
def extract_device_id(args, name):
|
def extract_device_id(args, name):
|
||||||
for x in range(len(args)):
|
for x in range(len(args)):
|
||||||
if name in args[x]: return args[x+1]
|
if name in args[x]:
|
||||||
|
return args[x + 1]
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_optimal_device():
|
def get_optimal_device():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
from modules import shared
|
from modules import shared
|
||||||
@ -52,10 +56,12 @@ def enable_tf32():
|
|||||||
|
|
||||||
errors.run(enable_tf32, "Enabling TF32")
|
errors.run(enable_tf32, "Enabling TF32")
|
||||||
|
|
||||||
|
cpu = torch.device("cpu")
|
||||||
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
|
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
dtype_vae = torch.float16
|
dtype_vae = torch.float16
|
||||||
|
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||||
if device.type == 'mps':
|
if device.type == 'mps':
|
||||||
@ -89,6 +95,11 @@ def autocast(disable=False):
|
|||||||
|
|
||||||
return torch.autocast("cuda")
|
return torch.autocast("cuda")
|
||||||
|
|
||||||
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||||
def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
|
def mps_contiguous(input_tensor, device):
|
||||||
def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
|
return input_tensor.contiguous() if device.type == 'mps' else input_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def mps_contiguous_to(input_tensor, device):
|
||||||
|
return mps_contiguous(input_tensor, device).to(device)
|
||||||
|
Loading…
Reference in New Issue
Block a user