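# aisatsu-api / main.py
# Hugging Face file-viewer residue (not part of the program):
#   avatar alt text: "vumichien's picture"
#   last commit: "Update main.py" (ad86b00)
#   viewer links: raw | history | blame
#   file size: 2.52 kB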
import zipfile
from base64 import b64encode
from io import BytesIO
from typing import Optional
from typing import Union

import numpy as np
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse
from huggingface_hub import hf_hub_download
from sahi.utils.cv import read_image_as_pil
from scipy.spatial import distance as dist
from speech_recognition import AudioFile, Recognizer
from ultralytics import YOLO

from utils import tts, read_image_file, pil_to_base64, get_hist
# Fetch the YOLOv8s checkpoint from the Hugging Face Hub and load it once at
# import time, so every request reuses the same model instance.
model_path = hf_hub_download(repo_id="ultralyticsplus/yolov8s", filename='yolov8s.pt')
model = YOLO(model_path)
# Class-id -> class-name mapping exposed by the loaded model.
CLASS = model.model.names
# Greeting text handed to tts() with language="ja" by the /aisatsu_api/ handler.
# The literal had been mis-decoded (UTF-8 bytes read through a single-byte
# code page); it is restored here to the Japanese text those bytes encode.
# NOTE(review): the recovered text contains an extra "い" compared with the
# standard greeting "おはようございます" - confirm which spelling is intended.
defaul_bot_voice = "おはいようございます"
# Minimum fraction of the image area a detection box must cover before the
# /aisatsu_api/ handler responds with a greeting.
area_thres = 0.3
app = FastAPI()
@app.get("/")
def read_root():
    """Root endpoint: return a fixed message confirming startup completed."""
    startup_status = {"Message": "Application startup complete"}
    return startup_status
def _largest_person_crop(image, boxes):
    """Return ``(crop, area_ratio)`` for the largest class-0 detection.

    Args:
        image: PIL image the detections refer to.
        boxes: The ``boxes`` attribute of a YOLO result, or ``None``.

    Returns:
        A tuple of the 64x64 crop of the winning box and the fraction of the
        image area that box covers. ``(None, 0)`` when nothing qualifies.
    """
    best_box, best_ratio = None, 0
    if boxes is None:
        return None, best_ratio

    image_area = image.width * image.height
    for xyxy, cls in zip(boxes.xyxy, boxes.cls):
        # Only class id 0 is considered.
        # NOTE(review): presumably "person" for these weights - confirm via CLASS.
        if int(cls) != 0:
            continue
        left, top, right, bottom = xyxy.tolist()
        ratio = (right - left) * (bottom - top) / image_area
        # ">=" keeps the later detection when two boxes cover the same area.
        if ratio >= best_ratio:
            best_box, best_ratio = (left, top, right, bottom), ratio

    if best_box is None:
        return None, best_ratio
    # Crop and resize once, for the winning box only.
    return image.crop(best_box).resize((64, 64)), best_ratio


def _zip_attachment(paths, zip_filename):
    """Bundle the files at *paths* into an in-memory zip download response."""
    archive_buffer = BytesIO()
    # The context manager finalizes and closes the archive on exit.
    with zipfile.ZipFile(archive_buffer, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
        for path in paths:
            zf.write(path)
    return StreamingResponse(
        iter([archive_buffer.getvalue()]),
        media_type="application/x-zip-compressed",
        headers={"Content-Disposition": f"attachment;filename={zip_filename}"},
    )


@app.post("/aisatsu_api/")
async def predict_api(
    file: UploadFile = File(...),
    last_seen: Union[UploadFile, None] = File(None),
):
    """Detect the largest class-0 object in an upload and answer with a greeting.

    Args:
        file: Uploaded image to run detection on.
        last_seen: Optional uploaded image of the previously seen subject. When
            given, the new crop is compared with it by histogram distance.

    Returns:
        A zip attachment (``final_archive.zip``) containing the files produced
        by ``tts`` and ``pil_to_base64`` when the detection covers at least
        ``area_thres`` of the image and differs enough from ``last_seen``.
        Otherwise the JSON body ``{"message": "No face detected"}``.
    """
    image = read_image_file(await file.read())
    results = model.predict(image, show=False)[0]
    image = read_image_as_pil(image)
    out_img, most_close = _largest_person_crop(image, results.boxes)

    # Default equals the threshold below, so a request without `last_seen`
    # always passes the "different enough" check.
    diff_value = 0.5
    if last_seen is not None:
        last_seen_image = read_image_file(await last_seen.read())
        if out_img is not None:
            diff_value = dist.euclidean(get_hist(out_img), get_hist(last_seen_image))
    print(most_close, diff_value)

    if out_img is not None and most_close >= area_thres and diff_value >= 0.5:
        voice_bot_path = tts(defaul_bot_voice, language="ja")
        # NOTE(review): despite its name, the return value is used as a file
        # path by ZipFile.write - confirm against utils.pil_to_base64.
        image_bot_path = pil_to_base64(out_img)
        return _zip_attachment([voice_bot_path, image_bot_path], "final_archive.zip")
    return {"message": "No face detected"}