diff --git a/modules/api/api.py b/modules/api/api.py index 648bd6a8..efcedbba 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -4,7 +4,7 @@ import time import uvicorn from threading import Lock from io import BytesIO -from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image +from gradio.processing_utils import decode_base64_to_file from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest @@ -41,6 +41,10 @@ def setUpscalers(req: dict): reqDict.pop('upscaler_2') return reqDict +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + return Image.open(BytesIO(base64.b64decode(encoding))) def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: @@ -134,10 +138,7 @@ class Api: mask = img2imgreq.mask if mask: - if mask.startswith("data:image/"): - mask = decode_base64_to_image(mask) - else: - mask = Image.open(BytesIO(base64.b64decode(mask))) + mask = decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, @@ -151,10 +152,7 @@ class Api: imgs = [] for img in init_images: - if img.startswith("data:image/"): - img = decode_base64_to_image(img) - else: - img = Image.open(BytesIO(base64.b64decode(img))) + img = decode_base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs