stuff related to torch version change
This commit is contained in:
parent
9eb49b04e3
commit
ee71eee181
@ -121,12 +121,12 @@ def run_python(code, desc=None, errdesc=None):
|
|||||||
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
||||||
|
|
||||||
|
|
||||||
def run_pip(args, desc=None):
|
def run_pip(args, desc=None, live=False):
|
||||||
if skip_install:
|
if skip_install:
|
||||||
return
|
return
|
||||||
|
|
||||||
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
||||||
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||||
|
|
||||||
|
|
||||||
def check_run_python(code):
|
def check_run_python(code):
|
||||||
@ -225,7 +225,7 @@ def run_extensions_installers(settings_file):
|
|||||||
def prepare_environment():
|
def prepare_environment():
|
||||||
global skip_install
|
global skip_install
|
||||||
|
|
||||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118")
|
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --index-url https://download.pytorch.org/whl/cu118")
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
|
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
|
||||||
@ -271,7 +271,7 @@ def prepare_environment():
|
|||||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
if platform.python_version().startswith("3.10"):
|
if platform.python_version().startswith("3.10"):
|
||||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
|
||||||
else:
|
else:
|
||||||
print("Installation of xformers is not supported in this version of Python.")
|
print("Installation of xformers is not supported in this version of Python.")
|
||||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# this code is adapted from the script contributed by anon from /h/
|
# this code is adapted from the script contributed by anon from /h/
|
||||||
|
|
||||||
import io
|
|
||||||
import pickle
|
import pickle
|
||||||
import collections
|
import collections
|
||||||
import sys
|
import sys
|
||||||
@ -12,11 +11,9 @@ import _codecs
|
|||||||
import zipfile
|
import zipfile
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||||
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||||
|
|
||||||
|
|
||||||
def encode(*args):
|
def encode(*args):
|
||||||
out = _codecs.encode(*args)
|
out = _codecs.encode(*args)
|
||||||
return out
|
return out
|
||||||
@ -27,7 +24,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||||||
|
|
||||||
def persistent_load(self, saved_id):
|
def persistent_load(self, saved_id):
|
||||||
assert saved_id[0] == 'storage'
|
assert saved_id[0] == 'storage'
|
||||||
return TypedStorage()
|
return TypedStorage(_internal=True)
|
||||||
|
|
||||||
def find_class(self, module, name):
|
def find_class(self, module, name):
|
||||||
if self.extra_handler is not None:
|
if self.extra_handler is not None:
|
||||||
|
@ -25,6 +25,6 @@ lark==1.1.2
|
|||||||
inflection==0.5.1
|
inflection==0.5.1
|
||||||
GitPython==3.1.30
|
GitPython==3.1.30
|
||||||
torchsde==0.2.5
|
torchsde==0.2.5
|
||||||
safetensors==0.3.0
|
safetensors==0.3.1
|
||||||
httpcore<=0.15
|
httpcore<=0.15
|
||||||
fastapi==0.94.0
|
fastapi==0.94.0
|
||||||
|
6
webui.py
6
webui.py
@ -21,6 +21,8 @@ import torch
|
|||||||
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
|
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
|
||||||
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||||
|
warnings.filterwarnings(action='ignore', category=UserWarning, message='TypedStorage is deprecated')
|
||||||
|
|
||||||
|
|
||||||
startup_timer.record("import torch")
|
startup_timer.record("import torch")
|
||||||
|
|
||||||
@ -113,7 +115,7 @@ def check_versions():
|
|||||||
if shared.cmd_opts.skip_version_check:
|
if shared.cmd_opts.skip_version_check:
|
||||||
return
|
return
|
||||||
|
|
||||||
expected_torch_version = "1.13.1"
|
expected_torch_version = "2.0.0"
|
||||||
|
|
||||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||||
errors.print_error_explanation(f"""
|
errors.print_error_explanation(f"""
|
||||||
@ -126,7 +128,7 @@ there are reports of issues with training tab on the latest version.
|
|||||||
Use --skip-version-check commandline argument to disable this check.
|
Use --skip-version-check commandline argument to disable this check.
|
||||||
""".strip())
|
""".strip())
|
||||||
|
|
||||||
expected_xformers_version = "0.0.16rc425"
|
expected_xformers_version = "0.0.17"
|
||||||
if shared.xformers_available:
|
if shared.xformers_available:
|
||||||
import xformers
|
import xformers
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user