Prathamesh1420 commited on
Commit
3265a39
·
verified ·
1 Parent(s): 9855222

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -21
app.py CHANGED
@@ -8,8 +8,8 @@ import gradio as gr
8
  from langchain_core.messages import HumanMessage
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
 
11
- # ✅ Set up Google API Key
12
- os.environ["GOOGLE_API_KEY"] = "AIzaSyBDVss_CkJLMcWnOxYs3LH0Q7LDi732voE" # Replace with your actual API Key
13
 
14
# ✅ Initialize the Gemini model
15
  gemini_model = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
@@ -18,15 +18,13 @@ gemini_model = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
18
  yolo_model = YOLO("best.pt")
19
  names = yolo_model.names # Class names from the YOLO model
20
 
21
- processed_ids = set() # Store track IDs of processed bottles
22
-
23
  def encode_image_to_base64(image):
24
  _, img_buffer = cv2.imencode('.jpg', image)
25
  return base64.b64encode(img_buffer).decode('utf-8')
26
 
27
  def analyze_image_with_gemini(image):
28
- if image is None:
29
- return "No image available for analysis."
30
 
31
  image_data = encode_image_to_base64(image)
32
  message = HumanMessage(content=[
@@ -51,16 +49,22 @@ def process_video(video_path):
51
  if not cap.isOpened():
52
  return "Error: Could not open video file."
53
 
54
- frame_list = []
55
- frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
56
- center_line = frame_height // 2 # Middle of the frame
 
 
 
 
 
 
57
 
58
  while True:
59
  ret, frame = cap.read()
60
  if not ret:
61
  break
62
 
63
- frame = cv2.resize(frame, (1020, 500))
64
  results = yolo_model.track(frame, persist=True)
65
 
66
  if results[0].boxes is not None:
@@ -70,26 +74,25 @@ def process_video(video_path):
70
 
71
  for box, track_id, class_id in zip(boxes, track_ids, class_ids):
72
  x1, y1, x2, y2 = box
73
- center_y = (y1 + y2) // 2 # Compute center point
 
74
 
75
  cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
76
  cvzone.putTextRect(frame, f'ID: {track_id}', (x2, y2), 1, 1)
77
  cvzone.putTextRect(frame, f'{names[class_id]}', (x1, y1), 1, 1)
78
 
79
- # Draw center line
80
- cv2.line(frame, (0, center_line), (frame.shape[1], center_line), (0, 0, 255), 2)
81
-
82
- # Check if the bottle crosses the center line
83
- if center_y >= center_line and track_id not in processed_ids:
84
- processed_ids.add(track_id)
85
  crop = frame[y1:y2, x1:x2]
86
  response = analyze_image_with_gemini(crop)
87
- print(response)
 
88
 
89
- frame_list.append(frame)
90
 
91
  cap.release()
92
- return frame_list[0] if frame_list else "Error: No frames processed."
 
 
93
 
94
  def gradio_interface(video_path):
95
  if video_path is None:
@@ -100,7 +103,7 @@ def gradio_interface(video_path):
100
  iface = gr.Interface(
101
  fn=gradio_interface,
102
  inputs=gr.File(type="filepath", label="Upload Video"),
103
- outputs=gr.Image(label="Processed Frame"),
104
  title="YOLO + Gemini AI Video Analysis",
105
  description="Upload a video to detect objects and analyze them using Gemini AI.",
106
  )
 
8
  from langchain_core.messages import HumanMessage
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
 
11
+ # ✅ Set up Google API Key (Avoid hardcoding in production)
12
+ os.environ["GOOGLE_API_KEY"] = "YOUR_GOOGLE_API_KEY"
13
 
14
# ✅ Initialize the Gemini model
15
  gemini_model = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
 
18
  yolo_model = YOLO("best.pt")
19
  names = yolo_model.names # Class names from the YOLO model
20
 
 
 
21
  def encode_image_to_base64(image):
22
  _, img_buffer = cv2.imencode('.jpg', image)
23
  return base64.b64encode(img_buffer).decode('utf-8')
24
 
25
  def analyze_image_with_gemini(image):
26
+ if image is None or image.shape[0] == 0 or image.shape[1] == 0:
27
+ return "Error: Invalid image."
28
 
29
  image_data = encode_image_to_base64(image)
30
  message = HumanMessage(content=[
 
49
  if not cap.isOpened():
50
  return "Error: Could not open video file."
51
 
52
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
53
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
54
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
55
+
56
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
57
+ output_video_path = "output.mp4"
58
+ out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
59
+
60
+ vertical_center = width // 2
61
 
62
  while True:
63
  ret, frame = cap.read()
64
  if not ret:
65
  break
66
 
67
+ frame = cv2.resize(frame, (width, height))
68
  results = yolo_model.track(frame, persist=True)
69
 
70
  if results[0].boxes is not None:
 
74
 
75
  for box, track_id, class_id in zip(boxes, track_ids, class_ids):
76
  x1, y1, x2, y2 = box
77
+ center_x = (x1 + x2) // 2
78
+ center_y = (y1 + y2) // 2
79
 
80
  cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
81
  cvzone.putTextRect(frame, f'ID: {track_id}', (x2, y2), 1, 1)
82
  cvzone.putTextRect(frame, f'{names[class_id]}', (x1, y1), 1, 1)
83
 
84
+ if abs(center_x - vertical_center) < 10: # If the center of the box is near the vertical center
 
 
 
 
 
85
  crop = frame[y1:y2, x1:x2]
86
  response = analyze_image_with_gemini(crop)
87
+
88
+ cvzone.putTextRect(frame, response, (x1, y1 - 10), 1, 1, colorT=(255, 255, 255), colorR=(0, 0, 255))
89
 
90
+ out.write(frame)
91
 
92
  cap.release()
93
+ out.release()
94
+
95
+ return output_video_path
96
 
97
  def gradio_interface(video_path):
98
  if video_path is None:
 
103
  iface = gr.Interface(
104
  fn=gradio_interface,
105
  inputs=gr.File(type="filepath", label="Upload Video"),
106
+ outputs=gr.Video(label="Processed Video"),
107
  title="YOLO + Gemini AI Video Analysis",
108
  description="Upload a video to detect objects and analyze them using Gemini AI.",
109
  )