muhammadsalmanalfaridzi commited on
Commit
f280e18
·
verified ·
1 Parent(s): d38d4bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -34
app.py CHANGED
@@ -17,8 +17,8 @@ workspace = os.getenv("ROBOFLOW_WORKSPACE")
17
  project_name = os.getenv("ROBOFLOW_PROJECT")
18
  model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
19
 
20
- # CountGD Config (menggantikan DINO-X)
21
- # Pastikan API key CountGD telah di-set di .env dengan key COUNTGD_API_KEY
22
  COUNTGD_API_KEY = os.getenv("COUNTGD_API_KEY")
23
 
24
  # Inisialisasi YOLO Model dari Roboflow
@@ -26,14 +26,11 @@ rf = Roboflow(api_key=rf_api_key)
26
  project = rf.workspace(workspace).project(project_name)
27
  yolo_model = project.version(model_version).model
28
 
29
- # List prompt untuk CountGD (misal: cans, bottle, mixed box)
30
- COUNTGD_PROMPTS = ["cans", "bottle", "mixed box"]
31
-
32
  # ========== Fungsi untuk Mengecek Overlap ==========
33
  def is_overlap(box1, boxes2, threshold=0.3):
34
  """
35
  Mengecek apakah box1 (format: (x_min, y_min, x_max, y_max)) overlap dengan salah satu box di boxes2.
36
- boxes2 adalah list bounding box dari YOLO dengan format (x_center, y_center, width, height).
37
  Mengembalikan True jika rasio overlap melebihi threshold.
38
  """
39
  x1_min, y1_min, x1_max, y1_max = box1
@@ -55,58 +52,65 @@ def is_overlap(box1, boxes2, threshold=0.3):
55
 
56
  # ========== Fungsi Deteksi Kombinasi ==========
57
  def detect_combined(image):
 
58
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
59
  image.save(temp_file, format="JPEG")
60
  temp_path = temp_file.name
61
-
62
  try:
63
  # ===== YOLO Detection (Produk Nestlé) =====
64
  yolo_pred = yolo_model.predict(temp_path, confidence=50, overlap=80).json()
65
  nestle_class_count = {}
66
- nestle_boxes = [] # Menyimpan bounding box YOLO dengan format (x_center, y_center, width, height)
67
  for pred in yolo_pred['predictions']:
68
  class_name = pred['class']
69
  nestle_class_count[class_name] = nestle_class_count.get(class_name, 0) + 1
70
  nestle_boxes.append((pred['x'], pred['y'], pred['width'], pred['height']))
71
  total_nestle = sum(nestle_class_count.values())
72
-
73
  # ===== CountGD Detection (Produk Kompetitor) =====
74
  url = "https://api.landing.ai/v1/tools/text-to-object-detection"
75
- files = {"image": open(temp_path, "rb")}
76
- # Menggunakan lebih dari satu prompt
77
- data = {"prompts": COUNTGD_PROMPTS, "model": "countgd"}
 
78
  headers = {"Authorization": f"Basic {COUNTGD_API_KEY}"}
79
- response = requests.post(url, files=files, data=data, headers=headers)
80
- result = response.json()
81
 
82
- competitor_class_count = {}
83
- competitor_boxes = [] # Menyimpan bounding box CountGD dengan format (x_min, y_min, x_max, y_max)
84
- if 'data' in result:
85
- # Asumsi API mengembalikan list deteksi pada data[0]
86
- for obj in result['data'][0]:
87
- if 'bounding_box' in obj:
88
- x1, y1, x2, y2 = obj['bounding_box']
89
- # Mengambil label jika tersedia, default 'unclassified'
90
- label = obj.get('label', 'unclassified')
91
- # Hanya tambahkan deteksi jika tidak overlap dengan deteksi YOLO
92
- if not is_overlap((x1, y1, x2, y2), nestle_boxes, threshold=0.3):
93
- competitor_class_count[label] = competitor_class_count.get(label, 0) + 1
94
- competitor_boxes.append((x1, y1, x2, y2))
 
 
 
 
 
 
 
95
  total_competitor = sum(competitor_class_count.values())
96
-
97
  # ===== Format Output Text =====
98
  result_text = "Product Nestlé\n\n"
99
  for class_name, count in nestle_class_count.items():
100
  result_text += f"{class_name}: {count}\n"
101
  result_text += f"\nTotal Products Nestlé: {total_nestle}\n\n"
102
  if total_competitor:
103
- result_text += "Produk Kompetitor (CountGD) :\n"
104
  for label, count in competitor_class_count.items():
105
  result_text += f"{label}: {count}\n"
106
  result_text += f"\nTotal Produk Kompetitor: {total_competitor}\n"
107
  else:
108
  result_text += "No Unclassified Products detected\n"
109
-
110
  # ===== Visualisasi =====
111
  img = cv2.imread(temp_path)
112
  # Gambar bounding box YOLO (hijau)
@@ -121,17 +125,16 @@ def detect_combined(image):
121
  for box in competitor_boxes:
122
  x1, y1, x2, y2 = box
123
  cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
124
- # Tampilkan label hasil CountGD
125
  cv2.putText(img, "unclassified", (int(x1), int(y1)-10),
126
  cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0,0,255), 3)
127
-
128
  output_path = "/tmp/combined_output.jpg"
129
  cv2.imwrite(output_path, img)
130
  return output_path, result_text
131
-
132
  except Exception as e:
133
  return temp_path, f"Error: {str(e)}"
134
-
135
  finally:
136
  if os.path.exists(temp_path):
137
  os.remove(temp_path)
@@ -157,12 +160,14 @@ def detect_objects_in_video(video_path):
157
  if not video_path:
158
  return None, f"Video conversion error: {err}"
159
 
 
160
  video = cv2.VideoCapture(video_path)
161
  frame_rate = int(video.get(cv2.CAP_PROP_FPS))
162
  frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
163
  frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
164
  frame_size = (frame_width, frame_height)
165
 
 
166
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
167
  output_video = cv2.VideoWriter(temp_output_path, fourcc, frame_rate, frame_size)
168
 
@@ -171,11 +176,14 @@ def detect_objects_in_video(video_path):
171
  if not ret:
172
  break
173
 
 
174
  frame_path = os.path.join(temp_frames_dir, f"frame_{frame_count}.jpg")
175
  cv2.imwrite(frame_path, frame)
176
 
 
177
  predictions = yolo_model.predict(frame_path, confidence=50, overlap=80).json()
178
 
 
179
  current_detections = {}
180
  for prediction in predictions['predictions']:
181
  class_name = prediction['class']
@@ -189,6 +197,7 @@ def detect_objects_in_video(video_path):
189
  cv2.putText(frame, class_name, (pt1[0], pt1[1]-10),
190
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
191
 
 
192
  object_counts = {}
193
  for detection_id in current_detections:
194
  cls = current_detections[detection_id]
 
17
  project_name = os.getenv("ROBOFLOW_PROJECT")
18
  model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
19
 
20
+ # CountGD Config (Replace DINO-X)
21
+ # Pastikan Anda sudah set COUNTGD_API_KEY di .env
22
  COUNTGD_API_KEY = os.getenv("COUNTGD_API_KEY")
23
 
24
  # Inisialisasi YOLO Model dari Roboflow
 
26
  project = rf.workspace(workspace).project(project_name)
27
  yolo_model = project.version(model_version).model
28
 
 
 
 
29
  # ========== Fungsi untuk Mengecek Overlap ==========
30
  def is_overlap(box1, boxes2, threshold=0.3):
31
  """
32
  Mengecek apakah box1 (format: (x_min, y_min, x_max, y_max)) overlap dengan salah satu box di boxes2.
33
+ boxes2 adalah list bounding box YOLO dengan format (x_center, y_center, width, height).
34
  Mengembalikan True jika rasio overlap melebihi threshold.
35
  """
36
  x1_min, y1_min, x1_max, y1_max = box1
 
52
 
53
  # ========== Fungsi Deteksi Kombinasi ==========
54
  def detect_combined(image):
55
+ # Simpan image ke file sementara
56
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
57
  image.save(temp_file, format="JPEG")
58
  temp_path = temp_file.name
59
+
60
  try:
61
  # ===== YOLO Detection (Produk Nestlé) =====
62
  yolo_pred = yolo_model.predict(temp_path, confidence=50, overlap=80).json()
63
  nestle_class_count = {}
64
+ nestle_boxes = [] # Menyimpan bounding box YOLO (format: x_center, y_center, width, height)
65
  for pred in yolo_pred['predictions']:
66
  class_name = pred['class']
67
  nestle_class_count[class_name] = nestle_class_count.get(class_name, 0) + 1
68
  nestle_boxes.append((pred['x'], pred['y'], pred['width'], pred['height']))
69
  total_nestle = sum(nestle_class_count.values())
70
+
71
  # ===== CountGD Detection (Produk Kompetitor) =====
72
  url = "https://api.landing.ai/v1/tools/text-to-object-detection"
73
+ competitor_class_count = {}
74
+ competitor_boxes = [] # Menyimpan bounding box CountGD (format: x_min, y_min, x_max, y_max)
75
+ # Daftar prompt yang akan digunakan
76
+ COUNTGD_PROMPTS = ["cans", "bottle", "mixed box"]
77
  headers = {"Authorization": f"Basic {COUNTGD_API_KEY}"}
 
 
78
 
79
+ for prompt in COUNTGD_PROMPTS:
80
+ # Untuk setiap prompt, buka file gambar dan kirim request
81
+ with open(temp_path, "rb") as f:
82
+ files = {"image": f}
83
+ data = {"prompts": [prompt], "model": "countgd"}
84
+ response = requests.post(url, files=files, data=data, headers=headers)
85
+ result = response.json()
86
+ # Cek apakah API mengembalikan data
87
+ if 'data' in result and result['data']:
88
+ detections = result['data'][0]
89
+ for obj in detections:
90
+ if 'bounding_box' in obj:
91
+ x1, y1, x2, y2 = obj['bounding_box']
92
+ countgd_box = (x1, y1, x2, y2)
93
+ # Hanya tambahkan deteksi jika tidak overlap signifikan dengan YOLO
94
+ if not is_overlap(countgd_box, nestle_boxes, threshold=0.3):
95
+ # Gunakan label dari respons jika ada, jika tidak gunakan prompt sebagai default
96
+ label = obj.get('label', prompt)
97
+ competitor_class_count[label] = competitor_class_count.get(label, 0) + 1
98
+ competitor_boxes.append(countgd_box)
99
  total_competitor = sum(competitor_class_count.values())
100
+
101
  # ===== Format Output Text =====
102
  result_text = "Product Nestlé\n\n"
103
  for class_name, count in nestle_class_count.items():
104
  result_text += f"{class_name}: {count}\n"
105
  result_text += f"\nTotal Products Nestlé: {total_nestle}\n\n"
106
  if total_competitor:
107
+ result_text += "Produk Kompetitor (CountGD):\n"
108
  for label, count in competitor_class_count.items():
109
  result_text += f"{label}: {count}\n"
110
  result_text += f"\nTotal Produk Kompetitor: {total_competitor}\n"
111
  else:
112
  result_text += "No Unclassified Products detected\n"
113
+
114
  # ===== Visualisasi =====
115
  img = cv2.imread(temp_path)
116
  # Gambar bounding box YOLO (hijau)
 
125
  for box in competitor_boxes:
126
  x1, y1, x2, y2 = box
127
  cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
 
128
  cv2.putText(img, "unclassified", (int(x1), int(y1)-10),
129
  cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0,0,255), 3)
130
+
131
  output_path = "/tmp/combined_output.jpg"
132
  cv2.imwrite(output_path, img)
133
  return output_path, result_text
134
+
135
  except Exception as e:
136
  return temp_path, f"Error: {str(e)}"
137
+
138
  finally:
139
  if os.path.exists(temp_path):
140
  os.remove(temp_path)
 
160
  if not video_path:
161
  return None, f"Video conversion error: {err}"
162
 
163
+ # Buka video untuk diproses
164
  video = cv2.VideoCapture(video_path)
165
  frame_rate = int(video.get(cv2.CAP_PROP_FPS))
166
  frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
167
  frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
168
  frame_size = (frame_width, frame_height)
169
 
170
+ # Setup VideoWriter untuk output
171
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
172
  output_video = cv2.VideoWriter(temp_output_path, fourcc, frame_rate, frame_size)
173
 
 
176
  if not ret:
177
  break
178
 
179
+ # Simpan frame untuk deteksi YOLO
180
  frame_path = os.path.join(temp_frames_dir, f"frame_{frame_count}.jpg")
181
  cv2.imwrite(frame_path, frame)
182
 
183
+ # YOLO detection pada frame
184
  predictions = yolo_model.predict(frame_path, confidence=50, overlap=80).json()
185
 
186
+ # Gambar deteksi YOLO pada frame
187
  current_detections = {}
188
  for prediction in predictions['predictions']:
189
  class_name = prediction['class']
 
197
  cv2.putText(frame, class_name, (pt1[0], pt1[1]-10),
198
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
199
 
200
+ # Hitung dan tampilkan jumlah deteksi pada frame
201
  object_counts = {}
202
  for detection_id in current_detections:
203
  cls = current_detections[detection_id]