muhammadsalmanalfaridzi committed on
Commit
cf51dd8
·
verified ·
1 Parent(s): ff5c315

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -0
app.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from dotenv import load_dotenv
3
+ from roboflow import Roboflow
4
+ import tempfile
5
+ import os
6
+ import cv2
7
+ import numpy as np
8
+ from dds_cloudapi_sdk import Config, Client
9
+ from dds_cloudapi_sdk.tasks.dinox import DinoxTask
10
+ from dds_cloudapi_sdk.tasks.types import DetectionTarget
11
+ from dds_cloudapi_sdk import TextPrompt
12
+ import supervision as sv
13
+
14
# ========== Configuration ==========
# Pull settings from a local .env file (if any) into the process environment.
load_dotenv()

# Roboflow Config
rf_api_key = os.getenv("ROBOFLOW_API_KEY")
workspace = os.getenv("ROBOFLOW_WORKSPACE")
project_name = os.getenv("ROBOFLOW_PROJECT")

# The model version must be an integer. Validate it explicitly so a missing or
# malformed variable fails with a message naming the variable, instead of the
# opaque "int() argument must be ... not 'NoneType'" raised by int(None).
_raw_model_version = os.getenv("ROBOFLOW_MODEL_VERSION")
try:
    model_version = int(_raw_model_version)
except (TypeError, ValueError) as exc:
    raise RuntimeError(
        "ROBOFLOW_MODEL_VERSION must be set to an integer, "
        f"got {_raw_model_version!r}"
    ) from exc

# DINO-X Config
DINOX_API_KEY = os.getenv("DINO_X_API_KEY")
DINOX_PROMPT = "beverage . bottle"  # Customize for competitor products

# Initialize Models
rf = Roboflow(api_key=rf_api_key)
project = rf.workspace(workspace).project(project_name)
yolo_model = project.version(model_version).model
dinox_config = Config(DINOX_API_KEY)
dinox_client = Client(dinox_config)
33
+
34
+ # ========== Combined Detection Function ==========
35
def detect_combined(image):
    """Detect Nestlé and competitor products in one image.

    The image is written to a temporary JPEG, sent to the Roboflow model
    (``yolo_model``) and to DINO-X (``dinox_client`` with ``DINOX_PROMPT``).
    DINO-X detections that ``is_overlap`` considers covered by a Roboflow
    detection are dropped; the remaining ones are counted as competitors.

    Args:
        image: PIL image to analyse, or ``None`` when nothing was supplied.

    Returns:
        A ``(output_path, result_text)`` tuple. On success ``output_path`` is
        a freshly created annotated JPEG (the caller owns the file) and
        ``result_text`` is the per-class count report. On failure the tuple is
        ``(None, "Error: ...")``.
    """
    if image is None:
        return None, "Error: no image provided"

    # Reserve a unique path for the JPEG both detectors read. The handle is
    # closed immediately so the file can be (re)written by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        temp_path = temp_file.name

    try:
        # JPEG cannot store alpha/palette data, so normalise to RGB first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(temp_path, format="JPEG")

        # [1] YOLO: Detect Nestlé Products
        yolo_pred = yolo_model.predict(temp_path, confidence=60, overlap=80).json()
        nestle_predictions = yolo_pred['predictions']
        nestle_class_count = {}
        nestle_boxes = []
        for pred in nestle_predictions:
            class_name = pred['class']
            nestle_class_count[class_name] = nestle_class_count.get(class_name, 0) + 1
            # Kept in the centre/size form that is_overlap expects.
            nestle_boxes.append((pred['x'], pred['y'], pred['width'], pred['height']))
        total_nestle = sum(nestle_class_count.values())

        # [2] DINO-X: Detect Competitor Products
        image_url = dinox_client.upload_file(temp_path)
        task = DinoxTask(
            image_url=image_url,
            prompts=[TextPrompt(text=DINOX_PROMPT)],
            bbox_threshold=0.25,
            targets=[DetectionTarget.BBox]
        )
        dinox_client.run_task(task)
        dinox_pred = task.result.objects

        # Filter & Count Competitors: anything already claimed by a Nestlé
        # detection is not a competitor.
        competitor_class_count = {}
        competitor_boxes = []
        for obj in dinox_pred:
            dinox_box = obj.bbox
            if is_overlap(dinox_box, nestle_boxes):
                continue
            class_name = obj.category.strip().lower()
            competitor_class_count[class_name] = competitor_class_count.get(class_name, 0) + 1
            competitor_boxes.append({
                "class": class_name,
                "box": dinox_box,
                "confidence": obj.score
            })
        total_competitor = sum(competitor_class_count.values())

        # [3] Format Output
        parts = ["Product Nestle\n\n"]
        parts.extend(
            f"{class_name}: {count}\n"
            for class_name, count in nestle_class_count.items()
        )
        parts.append(f"\nTotal Product Nestle: {total_nestle}\n\n")
        parts.append("Competitor Products\n\n")
        if competitor_class_count:
            parts.extend(
                f"{class_name}: {count}\n"
                for class_name, count in competitor_class_count.items()
            )
        else:
            parts.append("No competitors detected\n")
        parts.append(f"\nTotal Competitor: {total_competitor}")
        result_text = "".join(parts)

        # [4] Visualization
        img = cv2.imread(temp_path)
        # Nestlé (Green) - predictions are centre/size, convert to corners.
        for pred in nestle_predictions:
            x, y, w, h = pred['x'], pred['y'], pred['width'], pred['height']
            left, top = int(x - w / 2), int(y - h / 2)
            right, bottom = int(x + w / 2), int(y + h / 2)
            cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0), 2)
            cv2.putText(img, pred['class'], (left, int(y - h / 2 - 10)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        # Competitors (Red) - boxes are unpacked as corner coordinates.
        for comp in competitor_boxes:
            x1, y1, x2, y2 = comp['box']
            cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
            cv2.putText(img, f"{comp['class']} {comp['confidence']:.2f}",
                        (int(x1), int(y1 - 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

        # Use a unique output file per call so overlapping requests cannot
        # overwrite each other's result.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as out_file:
            output_path = out_file.name
        cv2.imwrite(output_path, img)
        return output_path, result_text
    except Exception as e:
        # The input temp file is removed in ``finally``; returning its path
        # would hand the caller a file that no longer exists.
        return None, f"Error: {str(e)}"
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
111
+
112
+ # ========== Overlap Detection Function ==========
113
def is_overlap(box1, boxes2, threshold=0.3):
    """Return True if any reference box covers more than ``threshold`` of ``box1``.

    Args:
        box1: Corner-format box ``(x_min, y_min, x_max, y_max)``.
        boxes2: Iterable of centre-format boxes ``(x_center, y_center, width,
            height)``.
        threshold: Fraction of ``box1``'s own area that must be covered by a
            single reference box for the result to be True (strictly greater
            than).

    Returns:
        True when at least one box in ``boxes2`` covers more than
        ``threshold`` of ``box1``; False otherwise. A ``box1`` with no
        positive area has no coverable area and always yields False instead
        of dividing by zero.
    """
    x1_min, y1_min, x1_max, y1_max = box1
    # Loop-invariant; also guards the division below against degenerate boxes.
    area_box1 = (x1_max - x1_min) * (y1_max - y1_min)
    if area_box1 <= 0:
        return False

    for x2, y2, w2, h2 in boxes2:
        # Convert the centre/size box to corner coordinates.
        x2_min = x2 - w2 / 2
        x2_max = x2 + w2 / 2
        y2_min = y2 - h2 / 2
        y2_max = y2 + h2 / 2

        # Extent of the intersection along each axis (negative => disjoint).
        dx = min(x1_max, x2_max) - max(x1_min, x2_min)
        dy = min(y1_max, y2_max) - max(y1_min, y2_min)
        if dx >= 0 and dy >= 0 and (dx * dy) / area_box1 > threshold:
            return True
    return False
129
+
130
+ # ========== Video Detection Function ==========
131
def detect_objects_in_video(video_path):
    """Annotate every frame of a video with Roboflow detections.

    Each frame is written to a scratch JPEG, passed to ``yolo_model.predict``,
    overlaid with per-class counts and green bounding boxes, and appended to a
    new ``mp4v``-encoded video.

    Args:
        video_path: Path of the input video file.

    Returns:
        Path of the annotated output video (a freshly created temporary
        ``.mp4``; the caller owns the file).

    Raises:
        gr.Error: If no video was given, it cannot be opened, or processing
            fails. The single Gradio output wired to this function cannot
            carry an ``(None, message)`` tuple, so failures are raised.
    """
    if not video_path:
        raise gr.Error("An error occurred: no video provided")

    video = None
    writer = None
    try:
        video = cv2.VideoCapture(video_path)
        if not video.isOpened():
            raise ValueError(f"could not open video: {video_path}")

        # Keep the fractional rate (e.g. 29.97) so the output does not drift;
        # fall back to 30 fps when the container reports no usable rate.
        frame_rate = video.get(cv2.CAP_PROP_FPS)
        if not frame_rate > 0:
            frame_rate = 30.0
        frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_size = (frame_width, frame_height)

        # Unique output file per call so overlapping requests do not clobber
        # each other's result.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as out_file:
            temp_output_path = out_file.name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(temp_output_path, fourcc, frame_rate, frame_size)

        # Scratch directory is removed automatically, even on failure.
        with tempfile.TemporaryDirectory() as temp_frames_dir:
            # predict() is called with a file path, so each frame is written
            # to disk; one file is reused because frames are handled in turn.
            frame_path = os.path.join(temp_frames_dir, "frame.jpg")
            while True:
                ret, frame = video.read()
                if not ret:
                    break
                cv2.imwrite(frame_path, frame)
                predictions = yolo_model.predict(frame_path, confidence=60, overlap=80).json()
                detections = predictions['predictions']

                class_count = {}
                for prediction in detections:
                    class_name = prediction['class']
                    class_count[class_name] = class_count.get(class_name, 0) + 1

                # Per-class counts, stacked from the top-left corner.
                text_offset = 30
                y_position = 30
                for class_name, count in class_count.items():
                    cv2.putText(frame, f"{class_name}: {count}", (10, y_position),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2, cv2.LINE_AA)
                    y_position += text_offset

                # Predictions are centre/size; convert to corners for drawing.
                for prediction in detections:
                    x, y, w, h = prediction['x'], prediction['y'], prediction['width'], prediction['height']
                    cv2.rectangle(frame, (int(x - w / 2), int(y - h / 2)),
                                  (int(x + w / 2), int(y + h / 2)), (0, 255, 0), 2)
                    cv2.putText(frame, prediction['class'], (int(x - w / 2), int(y - h / 2 - 10)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
                writer.write(frame)
        return temp_output_path
    except Exception as e:
        raise gr.Error(f"An error occurred: {e}") from e
    finally:
        # Runs before the return value reaches the caller, so the output file
        # is finalised and native handles are freed on every path.
        if video is not None:
            video.release()
        if writer is not None:
            writer.release()
172
+
173
+ # ========== Gradio Interface ==========
174
# Teal/slate colour scheme for the whole app.
_app_theme = gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")

with gr.Blocks(theme=_app_theme) as iface:
    # Centered page title.
    gr.Markdown("""
<div style="text-align: center;">
<h1>NESTLE - STOCK COUNTING</h1>
</div>
""")
    with gr.Row():
        # Image column: upload, run detect_combined, show annotated image and counts.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            detect_image_button = gr.Button("Detect Image")
            output_image = gr.Image(label="Detect Object")
            output_text = gr.Textbox(label="Counting Object")
            detect_image_button.click(
                fn=detect_combined,
                inputs=input_image,
                outputs=[output_image, output_text],
            )
        # Video column: upload, run detect_objects_in_video, show annotated video.
        with gr.Column():
            input_video = gr.Video(label="Input Video")
            detect_video_button = gr.Button("Detect Video")
            output_video = gr.Video(label="Output Video")
            detect_video_button.click(
                fn=detect_objects_in_video,
                inputs=input_video,
                outputs=[output_video],
            )

# Start the web server.
iface.launch()