diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index eb5ae372..c406ffb3 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -433,7 +433,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) + old_parallel_processing_allowed = shared.parallel_processing_allowed + if unload: + shared.parallel_processing_allowed = False shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) @@ -612,10 +615,12 @@ Last saved image: {html.escape(last_saved_image)}
if shared.opts.save_optimizer_state: hypernetwork.optimizer_state_dict = optimizer.state_dict() save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) + del optimizer hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) + shared.parallel_processing_allowed = old_parallel_processing_allowed return hypernetwork, filename diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index daf8d1b8..e28c357a 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -269,6 +269,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + old_parallel_processing_allowed = shared.parallel_processing_allowed pin_memory = shared.opts.pin_memory @@ -279,6 +280,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) if unload: + shared.parallel_processing_allowed = False shared.sd_model.first_stage_model.to(devices.cpu) embedding.vec.requires_grad = True @@ -450,6 +452,7 @@ Last saved image: {html.escape(last_saved_image)}
pbar.leave = False pbar.close() shared.sd_model.first_stage_model.to(devices.device) + shared.parallel_processing_allowed = old_parallel_processing_allowed return embedding, filename