vumichien commited on
Commit
b7f8699
Β·
1 Parent(s): 553c308

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +58 -0
main.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralyticsplus import YOLO
2
+ from base64 import b64encode
3
+ from speech_recognition import AudioFile, Recognizer
4
+ import numpy as np
5
+ from scipy.spatial import distance as dist
6
+
7
+ from sahi.utils.cv import read_image_as_pil
8
+ from fastapi import FastAPI, File, UploadFile, Form
9
+ from utils import tts, read_image_file, pil_to_base64, base64_to_pil, get_hist
10
+ from typing import Optional
11
+
12
+ model = YOLO('ultralyticsplus/yolov8s')
13
+ CLASS = model.model.names
14
+
15
+ app = FastAPI()
16
+ defaul_bot_voice = "γŠγ―γ„γ‚ˆγ†γ”γ–γ„γΎγ™"
17
+ area_thres = 0.3
18
+
19
+
20
+ @app.get("/")
21
+ def read_root():
22
+ return {"Message": "Application startup complete"}
23
+
24
+
25
+ @app.post("/aisatsu_api/")
26
+ async def predict_api(
27
+ file: UploadFile = File(...),
28
+ last_seen: Optional[str] = Form(None)
29
+ ):
30
+ image = read_image_file(await file.read())
31
+ results = model.predict(image, show=False)[0]
32
+ image = read_image_as_pil(image)
33
+ masks, boxes = results.masks, results.boxes
34
+ area_image = image.width * image.height
35
+ voice_bot = None
36
+ most_close = 0
37
+ out_img = None
38
+ diff_value = 0.5
39
+ if boxes is not None:
40
+ for xyxy, conf, cls in zip(boxes.xyxy, boxes.conf, boxes.cls):
41
+ if int(cls) != 0:
42
+ continue
43
+ box = xyxy.tolist()
44
+ area_rate = (box[2] - box[0]) * (box[3] - box[1]) / area_image
45
+ if area_rate >= most_close:
46
+ out_img = image.crop(tuple(box)).resize((128, 128))
47
+ most_close = area_rate
48
+ if last_seen is not None:
49
+ last_seen = base64_to_pil(last_seen)
50
+ if out_img is not None:
51
+ diff_value = dist.euclidean(get_hist(out_img), get_hist(last_seen))
52
+ print(most_close, diff_value)
53
+ if most_close >= area_thres and diff_value >= 0.5:
54
+ voice_bot = tts(defaul_bot_voice, language="ja")
55
+ return {
56
+ "voice": voice_bot,
57
+ "image": pil_to_base64(out_img) if out_img is not None else None
58
+ }