Update detection display
Browse files- app.py +18 -23
- requirements.txt +0 -2
app.py
CHANGED
@@ -3,9 +3,7 @@ import socket
|
|
3 |
import time
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
-
from PIL import Image
|
7 |
-
import supervision as sv
|
8 |
-
import cv2
|
9 |
import base64
|
10 |
import requests
|
11 |
import json
|
@@ -23,13 +21,14 @@ LINE_WIDTH = 2
|
|
23 |
print(f"Gradio version: {gr.__version__}")
|
24 |
|
25 |
# Define the inference function
|
26 |
-
def predict_image(
|
27 |
|
28 |
-
if isinstance(
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
33 |
|
34 |
# Encode the image data as base64
|
35 |
image_base64 = base64.b64encode(np.ascontiguousarray(img)).decode()
|
@@ -59,22 +58,18 @@ def predict_image(img, threshold):
|
|
59 |
detections = json_data['detections']
|
60 |
duration = json_data['duration']
|
61 |
|
62 |
-
#
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
scene=cv2_img,
|
69 |
-
detections=detections
|
70 |
-
)
|
71 |
-
image_with_predictions_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
|
72 |
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
|
77 |
-
return
|
78 |
|
79 |
|
80 |
# Define example images and their true labels for users to choose from
|
|
|
3 |
import time
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
+
from PIL import Image, ImageDraw
|
|
|
|
|
7 |
import base64
|
8 |
import requests
|
9 |
import json
|
|
|
21 |
print(f"Gradio version: {gr.__version__}")
|
22 |
|
23 |
# Define the inference function
|
24 |
+
def predict_image(image, threshold):
|
25 |
|
26 |
+
if not isinstance(image, Image.Image):
|
27 |
+
raise BaseException("predit_image(): input 'image' shoud be single RGB image in PIL format.")
|
28 |
+
|
29 |
+
img = np.array(image)
|
30 |
+
if len(img.shape) != 3 or img.shape[2] != 3:
|
31 |
+
raise BaseException("predit_image(): input 'image' shoud be single RGB image in PIL format.")
|
32 |
|
33 |
# Encode the image data as base64
|
34 |
image_base64 = base64.b64encode(np.ascontiguousarray(img)).decode()
|
|
|
58 |
detections = json_data['detections']
|
59 |
duration = json_data['duration']
|
60 |
|
61 |
+
# drow boxes on image
|
62 |
+
draw = ImageDraw.Draw(image)
|
63 |
+
|
64 |
+
for (class_name, coords, confidence) in detections:
|
65 |
+
if len(coords) != 4:
|
66 |
+
raise ValueError("Each detection should be a polygon with 8 coordinates (xyxyxyxy).")
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
points = [(coord[0], coord[1]) for coord in coords]
|
69 |
+
draw.polygon(points, outline="red", width=LINE_WIDTH)
|
70 |
+
draw.text((points[0][0], points[0][1]), class_name, fill="red")
|
71 |
|
72 |
+
return image, img.shape, len(detections), duration
|
73 |
|
74 |
|
75 |
# Define example images and their true labels for users to choose from
|
requirements.txt
CHANGED
@@ -1,3 +1 @@
|
|
1 |
-
opencv-python
|
2 |
-
supervision
|
3 |
requests
|
|
|
|
|
|
|
1 |
requests
|