Update app.py
app.py CHANGED
@@ -8,8 +8,8 @@ import gradio as gr
 from langchain_core.messages import HumanMessage
 from langchain_google_genai import ChatGoogleGenerativeAI
 
-# ✅ Set up Google API Key
-os.environ["GOOGLE_API_KEY"] = "
+# ✅ Set up Google API Key (Avoid hardcoding in production)
+os.environ["GOOGLE_API_KEY"] = "YOUR_GOOGLE_API_KEY"
 
 # ✅ Initialize the Gemini model
 gemini_model = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
@@ -18,15 +18,13 @@ gemini_model = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
 yolo_model = YOLO("best.pt")
 names = yolo_model.names # Class names from the YOLO model
 
-processed_ids = set() # Store track IDs of processed bottles
-
 def encode_image_to_base64(image):
     _, img_buffer = cv2.imencode('.jpg', image)
     return base64.b64encode(img_buffer).decode('utf-8')
 
 def analyze_image_with_gemini(image):
-    if image is None:
-        return "
+    if image is None or image.shape[0] == 0 or image.shape[1] == 0:
+        return "Error: Invalid image."
 
     image_data = encode_image_to_base64(image)
     message = HumanMessage(content=[
@@ -51,16 +49,22 @@ def process_video(video_path):
     if not cap.isOpened():
         return "Error: Could not open video file."
 
-    …
-    …
-    …
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    fps = int(cap.get(cv2.CAP_PROP_FPS))
+
+    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+    output_video_path = "output.mp4"
+    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
+
+    vertical_center = width // 2
 
     while True:
         ret, frame = cap.read()
         if not ret:
             break
 
-        frame = cv2.resize(frame, (
+        frame = cv2.resize(frame, (width, height))
         results = yolo_model.track(frame, persist=True)
 
         if results[0].boxes is not None:
@@ -70,26 +74,25 @@ def process_video(video_path):
 
             for box, track_id, class_id in zip(boxes, track_ids, class_ids):
                 x1, y1, x2, y2 = box
-                …
+                center_x = (x1 + x2) // 2
+                center_y = (y1 + y2) // 2
 
                 cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                 cvzone.putTextRect(frame, f'ID: {track_id}', (x2, y2), 1, 1)
                 cvzone.putTextRect(frame, f'{names[class_id]}', (x1, y1), 1, 1)
 
-                #
-                cv2.line(frame, (0, center_line), (frame.shape[1], center_line), (0, 0, 255), 2)
-
-                # Check if the bottle crosses the center line
-                if center_y >= center_line and track_id not in processed_ids:
-                    processed_ids.add(track_id)
+                if abs(center_x - vertical_center) < 10: # If the center of the box is near the vertical center
                     crop = frame[y1:y2, x1:x2]
                     response = analyze_image_with_gemini(crop)
-                    …
+
+                    cvzone.putTextRect(frame, response, (x1, y1 - 10), 1, 1, colorT=(255, 255, 255), colorR=(0, 0, 255))
 
-
+        out.write(frame)
 
     cap.release()
-    …
+    out.release()
+
+    return output_video_path
 
 def gradio_interface(video_path):
     if video_path is None:
@@ -100,7 +103,7 @@ def gradio_interface(video_path):
 iface = gr.Interface(
     fn=gradio_interface,
     inputs=gr.File(type="filepath", label="Upload Video"),
-    outputs=gr.
+    outputs=gr.Video(label="Processed Video"),
     title="YOLO + Gemini AI Video Analysis",
     description="Upload a video to detect objects and analyze them using Gemini AI.",
 )
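
The commit replaces the hardcoded key with a `"YOUR_GOOGLE_API_KEY"` placeholder, but a placeholder in source still invites accidental key commits. Below is a minimal sketch, not part of the commit above, of reading the key from the environment instead; it assumes the app runs as a Hugging Face Space where a repository secret named `GOOGLE_API_KEY` is exposed to the app as an environment variable:

```python
import os

# Sketch: read the Gemini key from the environment rather than hardcoding it.
# On Hugging Face Spaces, add GOOGLE_API_KEY as a Space secret; it is then
# available to the running app as an environment variable.
api_key = os.environ.get("GOOGLE_API_KEY")  # None if unset
if not api_key:
    raise RuntimeError(
        "GOOGLE_API_KEY is not set; add it as a Space secret "
        "or export it before launching the app."
    )
```

`ChatGoogleGenerativeAI` looks the key up from the environment by default, so once the variable is present no further wiring is needed. One related caveat in the new code: `cap.get(cv2.CAP_PROP_FPS)` can report 0 for some files, which would leave the `VideoWriter` with an invalid frame rate; a fallback such as `fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25` guards against that.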