Unload checkpoints on Request
…to free VRAM. New Action buttons in the settings to manually free and reload checkpoints, essentially juggling models between RAM and VRAM.
This commit is contained in:
parent
a9fed7c364
commit
4cbbb881ee
@ -18,7 +18,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
|||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin,Image
|
||||||
from modules.sd_models import checkpoints_list
|
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
@ -150,6 +150,8 @@ class Api:
|
|||||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
||||||
|
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||||
|
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
||||||
|
|
||||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||||
@ -412,6 +414,16 @@ class Api:
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def unloadapi(self):
|
||||||
|
unload_model_weights()
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def reloadapi(self):
|
||||||
|
reload_model_weights()
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
def skip(self):
|
def skip(self):
|
||||||
shared.state.skip()
|
shared.state.skip()
|
||||||
|
|
||||||
|
@ -494,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||||
del sd_model
|
del sd_model
|
||||||
checkpoints_loaded.clear()
|
checkpoints_loaded.clear()
|
||||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
|
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||||
return shared.sd_model
|
return shared.sd_model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -517,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
print(f"Weights loaded in {timer.summary()}.")
|
print(f"Weights loaded in {timer.summary()}.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
def unload_model_weights(sd_model=None, info=None):
|
||||||
|
from modules import lowvram, devices, sd_hijack
|
||||||
|
timer = Timer()
|
||||||
|
|
||||||
|
if shared.sd_model:
|
||||||
|
|
||||||
|
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
# shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
shared.sd_model.to(devices.cpu)
|
||||||
|
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||||
|
shared.sd_model = None
|
||||||
|
sd_model = None
|
||||||
|
gc.collect()
|
||||||
|
devices.torch_gc()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
print(f"Unloaded weights {timer.summary()}.")
|
||||||
|
|
||||||
|
return sd_model
|
@ -1491,11 +1491,33 @@ def create_ui():
|
|||||||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||||
|
with gr.Row():
|
||||||
|
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
||||||
|
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
||||||
|
|
||||||
with gr.TabItem("Licenses"):
|
with gr.TabItem("Licenses"):
|
||||||
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||||
|
|
||||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||||
|
|
||||||
|
|
||||||
|
def unload_sd_weights():
|
||||||
|
modules.sd_models.unload_model_weights()
|
||||||
|
|
||||||
|
def reload_sd_weights():
|
||||||
|
modules.sd_models.reload_model_weights()
|
||||||
|
|
||||||
|
unload_sd_model.click(
|
||||||
|
fn=unload_sd_weights,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
reload_sd_model.click(
|
||||||
|
fn=reload_sd_weights,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[]
|
||||||
|
)
|
||||||
|
|
||||||
request_notifications.click(
|
request_notifications.click(
|
||||||
fn=lambda: None,
|
fn=lambda: None,
|
||||||
|
Loading…
Reference in New Issue
Block a user