Add support for PyTorch nightly and local builds

This commit is contained in:
brkirch 2023-01-03 20:43:05 -05:00
parent 3bd737767b
commit 8111b5569d
2 changed files with 29 additions and 6 deletions

View File

@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs):
return orig_tensor_numpy(self, *args, **kwargs) return orig_tensor_numpy(self, *args, **kwargs)
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): orig_cumsum = torch.cumsum
torch.Tensor.to = tensor_to_fix orig_Tensor_cumsum = torch.Tensor.cumsum
torch.nn.functional.layer_norm = layer_norm_fix def cumsum_fix(input, cumsum_func, *args, **kwargs):
torch.Tensor.numpy = numpy_fix if input.device.type == 'mps':
output_dtype = kwargs.get('dtype', input.dtype)
if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
return cumsum_func(input, *args, **kwargs)
if has_mps():
if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
torch.Tensor.to = tensor_to_fix
torch.nn.functional.layer_norm = layer_norm_fix
torch.Tensor.numpy = numpy_fix
elif version.parse(torch.__version__) > version.parse("1.13.1"):
if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )

View File

@ -4,7 +4,7 @@ import threading
import time import time
import importlib import importlib
import signal import signal
import threading import re
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
@ -13,6 +13,11 @@ from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path from modules.paths import script_path
import torch
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0)
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.extras import modules.extras