remove parsing command line from devices.py
This commit is contained in:
parent
e80bdcab91
commit
50b5504401
@ -15,14 +15,10 @@ def extract_device_id(args, name):
|
||||
|
||||
def get_optimal_device():
|
||||
if torch.cuda.is_available():
|
||||
# CUDA device selection support:
|
||||
if "shared" not in sys.modules:
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop.
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
device_id = extract_device_id(sys.argv, '--device-id')
|
||||
else:
|
||||
device_id = shared.cmd_opts.device_id
|
||||
|
||||
from modules import shared
|
||||
|
||||
device_id = shared.cmd_opts.device_id
|
||||
|
||||
if device_id is not None:
|
||||
cuda_device = f"cuda:{device_id}"
|
||||
return torch.device(cuda_device)
|
||||
@ -49,7 +45,7 @@ def enable_tf32():
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
import torch
|
||||
from modules.devices import get_optimal_device
|
||||
from modules import devices
|
||||
|
||||
module_in_gpu = None
|
||||
cpu = torch.device("cpu")
|
||||
device = gpu = get_optimal_device()
|
||||
|
||||
|
||||
def send_everything_to_cpu():
|
||||
@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
if module_in_gpu is not None:
|
||||
module_in_gpu.to(cpu)
|
||||
|
||||
module.to(gpu)
|
||||
module.to(devices.device)
|
||||
module_in_gpu = module
|
||||
|
||||
# see below for register_forward_pre_hook;
|
||||
@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
||||
sd_model.to(device)
|
||||
sd_model.to(devices.device)
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
||||
|
||||
# register hooks for those the first two models
|
||||
@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
# so that only one of them is in GPU at a time
|
||||
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
||||
sd_model.model.to(device)
|
||||
sd_model.model.to(devices.device)
|
||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
||||
|
||||
# install hooks for bits of third model
|
||||
|
Loading…
Reference in New Issue
Block a user