Merge pull request #3810 from royshil/roy.add_simple_interrogate_api
Add a barebones CLIP interrogate API endpoint
This commit is contained in:
commit
5302e2cdd4
@ -63,6 +63,7 @@ class Api:
|
|||||||
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||||
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||||
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||||
|
self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||||
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||||
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
||||||
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||||
@ -214,6 +215,19 @@ class Api:
|
|||||||
|
|
||||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
||||||
|
|
||||||
|
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||||
|
image_b64 = interrogatereq.image
|
||||||
|
if image_b64 is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
|
|
||||||
|
img = self.__base64_to_image(image_b64)
|
||||||
|
|
||||||
|
# Override object param
|
||||||
|
with self.queue_lock:
|
||||||
|
processed = shared.interrogator.interrogate(img)
|
||||||
|
|
||||||
|
return InterrogateResponse(caption=processed)
|
||||||
|
|
||||||
def interruptapi(self):
|
def interruptapi(self):
|
||||||
shared.state.interrupt()
|
shared.state.interrupt()
|
||||||
|
|
||||||
|
@ -65,6 +65,7 @@ class PydanticModelGenerator:
|
|||||||
|
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
self._class_data = merge_class_params(class_instance)
|
self._class_data = merge_class_params(class_instance)
|
||||||
|
|
||||||
self._model_def = [
|
self._model_def = [
|
||||||
ModelDef(
|
ModelDef(
|
||||||
field=underscore(k),
|
field=underscore(k),
|
||||||
@ -167,6 +168,12 @@ class ProgressResponse(BaseModel):
|
|||||||
state: dict = Field(title="State", description="The current state snapshot")
|
state: dict = Field(title="State", description="The current state snapshot")
|
||||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||||
|
|
||||||
|
class InterrogateRequest(BaseModel):
|
||||||
|
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||||
|
|
||||||
|
class InterrogateResponse(BaseModel):
|
||||||
|
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
||||||
|
|
||||||
fields = {}
|
fields = {}
|
||||||
for key, value in opts.data.items():
|
for key, value in opts.data.items():
|
||||||
metadata = opts.data_labels.get(key)
|
metadata = opts.data_labels.get(key)
|
||||||
@ -231,3 +238,4 @@ class ArtistItem(BaseModel):
|
|||||||
name: str = Field(title="Name")
|
name: str = Field(title="Name")
|
||||||
score: float = Field(title="Score")
|
score: float = Field(title="Score")
|
||||||
category: str = Field(title="Category")
|
category: str = Field(title="Category")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user