API: use finally: for state.end()
This commit is contained in:
parent
f44feb6a10
commit
e430344347
@ -602,37 +602,35 @@ class Api:
|
||||
shared.state.begin(job="create_embedding")
|
||||
filename = create_embedding(**args) # create empty embedding
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||
shared.state.end()
|
||||
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
|
||||
def create_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin(job="create_hypernetwork")
|
||||
filename = create_hypernetwork(**args) # create empty embedding
|
||||
shared.state.end()
|
||||
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
def preprocess(self, args: dict):
|
||||
try:
|
||||
shared.state.begin(job="preprocess")
|
||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||
shared.state.end()
|
||||
return models.PreprocessResponse(info = 'preprocess complete')
|
||||
return models.PreprocessResponse(info='preprocess complete')
|
||||
except KeyError as e:
|
||||
shared.state.end()
|
||||
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
except Exception as e:
|
||||
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
||||
except FileNotFoundError as e:
|
||||
finally:
|
||||
shared.state.end()
|
||||
return models.PreprocessResponse(info=f'preprocess error: {e}')
|
||||
|
||||
def train_embedding(self, args: dict):
|
||||
try:
|
||||
@ -649,11 +647,11 @@ class Api:
|
||||
finally:
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||
except AssertionError as msg:
|
||||
shared.state.end()
|
||||
except Exception as msg:
|
||||
return models.TrainResponse(info=f"train embedding error: {msg}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
def train_hypernetwork(self, args: dict):
|
||||
try:
|
||||
@ -675,9 +673,10 @@ class Api:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||
except AssertionError:
|
||||
except Exception as exc:
|
||||
return models.TrainResponse(info=f"train embedding error: {exc}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"train embedding error: {error}")
|
||||
|
||||
def get_memory(self):
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user