remove parsing command line from devices.py
This commit is contained in:
parent
e80bdcab91
commit
50b5504401
@ -15,13 +15,9 @@ def extract_device_id(args, name):
|
|||||||
|
|
||||||
def get_optimal_device():
|
def get_optimal_device():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
# CUDA device selection support:
|
from modules import shared
|
||||||
if "shared" not in sys.modules:
|
|
||||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop.
|
device_id = shared.cmd_opts.device_id
|
||||||
sys.argv += shlex.split(commandline_args)
|
|
||||||
device_id = extract_device_id(sys.argv, '--device-id')
|
|
||||||
else:
|
|
||||||
device_id = shared.cmd_opts.device_id
|
|
||||||
|
|
||||||
if device_id is not None:
|
if device_id is not None:
|
||||||
cuda_device = f"cuda:{device_id}"
|
cuda_device = f"cuda:{device_id}"
|
||||||
@ -49,7 +45,7 @@ def enable_tf32():
|
|||||||
|
|
||||||
errors.run(enable_tf32, "Enabling TF32")
|
errors.run(enable_tf32, "Enabling TF32")
|
||||||
|
|
||||||
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
dtype_vae = torch.float16
|
dtype_vae = torch.float16
|
||||||
|
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
from modules.devices import get_optimal_device
|
from modules import devices
|
||||||
|
|
||||||
module_in_gpu = None
|
module_in_gpu = None
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
device = gpu = get_optimal_device()
|
|
||||||
|
|
||||||
|
|
||||||
def send_everything_to_cpu():
|
def send_everything_to_cpu():
|
||||||
@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
if module_in_gpu is not None:
|
if module_in_gpu is not None:
|
||||||
module_in_gpu.to(cpu)
|
module_in_gpu.to(cpu)
|
||||||
|
|
||||||
module.to(gpu)
|
module.to(devices.device)
|
||||||
module_in_gpu = module
|
module_in_gpu = module
|
||||||
|
|
||||||
# see below for register_forward_pre_hook;
|
# see below for register_forward_pre_hook;
|
||||||
@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
||||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
||||||
sd_model.to(device)
|
sd_model.to(devices.device)
|
||||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
||||||
|
|
||||||
# register hooks for those the first two models
|
# register hooks for those the first two models
|
||||||
@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
# so that only one of them is in GPU at a time
|
# so that only one of them is in GPU at a time
|
||||||
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
||||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
||||||
sd_model.model.to(device)
|
sd_model.model.to(devices.device)
|
||||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
||||||
|
|
||||||
# install hooks for bits of third model
|
# install hooks for bits of third model
|
||||||
|
Loading…
Reference in New Issue
Block a user