API: use finally: for state.end()

This commit is contained in:
Aarni Koskela 2023-06-30 13:11:49 +03:00
parent f44feb6a10
commit e430344347

View File

@ -602,37 +602,35 @@ class Api:
shared.state.begin(job="create_embedding") shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
return models.CreateResponse(info=f"create embedding filename: {filename}") return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e: except AssertionError as e:
shared.state.end()
return models.TrainResponse(info=f"create embedding error: {e}") return models.TrainResponse(info=f"create embedding error: {e}")
finally:
shared.state.end()
def create_hypernetwork(self, args: dict): def create_hypernetwork(self, args: dict):
try: try:
shared.state.begin(job="create_hypernetwork") shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding filename = create_hypernetwork(**args) # create empty embedding
shared.state.end()
return models.CreateResponse(info=f"create hypernetwork filename: {filename}") return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e: except AssertionError as e:
shared.state.end()
return models.TrainResponse(info=f"create hypernetwork error: {e}") return models.TrainResponse(info=f"create hypernetwork error: {e}")
finally:
shared.state.end()
def preprocess(self, args: dict): def preprocess(self, args: dict):
try: try:
shared.state.begin(job="preprocess") shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end() shared.state.end()
return models.PreprocessResponse(info = 'preprocess complete') return models.PreprocessResponse(info='preprocess complete')
except KeyError as e: except KeyError as e:
shared.state.end()
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except AssertionError as e: except Exception as e:
shared.state.end()
return models.PreprocessResponse(info=f"preprocess error: {e}") return models.PreprocessResponse(info=f"preprocess error: {e}")
except FileNotFoundError as e: finally:
shared.state.end() shared.state.end()
return models.PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict): def train_embedding(self, args: dict):
try: try:
@ -649,11 +647,11 @@ class Api:
finally: finally:
if not apply_optimizations: if not apply_optimizations:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg: except Exception as msg:
shared.state.end()
return models.TrainResponse(info=f"train embedding error: {msg}") return models.TrainResponse(info=f"train embedding error: {msg}")
finally:
shared.state.end()
def train_hypernetwork(self, args: dict): def train_hypernetwork(self, args: dict):
try: try:
@ -675,9 +673,10 @@ class Api:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
shared.state.end() shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError: except Exception as exc:
return models.TrainResponse(info=f"train embedding error: {exc}")
finally:
shared.state.end() shared.state.end()
return models.TrainResponse(info=f"train embedding error: {error}")
def get_memory(self): def get_memory(self):
try: try: