Sanshruth commited on
Commit
6cacbad
·
verified ·
1 Parent(s): 1c08976

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -170,7 +170,7 @@ def draw_angled_line(image, line_params, color=(0, 255, 0), thickness=2):
170
  _, _, start_point, end_point = line_params
171
  cv2.line(image, start_point, end_point, color, thickness)
172
 
173
- def process_video(confidence_threshold=0.5, selected_classes=None, stream_url=None, target_fps=30):
174
  """
175
  Processes the IP camera stream to count objects of the selected classes crossing the line.
176
  """
@@ -194,7 +194,7 @@ def process_video(confidence_threshold=0.5, selected_classes=None, stream_url=No
194
  errors.append("Error: Could not open stream.")
195
  return None, "\n".join(errors)
196
 
197
- model = YOLO(model="yolo12n.pt")
198
  crossed_objects = {}
199
  max_tracked_objects = 1000 # Maximum number of objects to track before clearing
200
 
@@ -295,7 +295,7 @@ with gr.Blocks() as demo:
295
 
296
  # Step 2: Select classes to detect
297
  gr.Markdown("### Step 2: Select Classes to Detect")
298
- model = YOLO(model="yolo12n.pt") # Load the model to get class names
299
  class_names = list(model.names.values()) # Get class names
300
  selected_classes = gr.CheckboxGroup(choices=class_names, label="Select Classes to Detect")
301
 
@@ -307,6 +307,10 @@ with gr.Blocks() as demo:
307
  gr.Markdown("### Step 4: Set Target FPS (Optional)")
308
  target_fps = gr.Slider(minimum=1, maximum=120*4, value=60, label="Target FPS")
309
 
 
 
 
 
310
  # Process the stream
311
  process_button = gr.Button("Process Stream")
312
 
@@ -317,7 +321,7 @@ with gr.Blocks() as demo:
317
  error_box = gr.Textbox(label="Errors/Warnings", interactive=False)
318
 
319
  # Event listener for processing the video
320
- process_button.click(process_video, inputs=[confidence_threshold, selected_classes, stream_url, target_fps], outputs=[output_image, error_box])
321
 
322
  # Launch the interface
323
- demo.launch(debug=True)
 
170
  _, _, start_point, end_point = line_params
171
  cv2.line(image, start_point, end_point, color, thickness)
172
 
173
+ def process_video(confidence_threshold=0.5, selected_classes=None, stream_url=None, target_fps=30, model_name="yolov8n.pt"):
174
  """
175
  Processes the IP camera stream to count objects of the selected classes crossing the line.
176
  """
 
194
  errors.append("Error: Could not open stream.")
195
  return None, "\n".join(errors)
196
 
197
+ model = YOLO(model=model_name)
198
  crossed_objects = {}
199
  max_tracked_objects = 1000 # Maximum number of objects to track before clearing
200
 
 
295
 
296
  # Step 2: Select classes to detect
297
  gr.Markdown("### Step 2: Select Classes to Detect")
298
+ model = YOLO(model="yolov8n.pt") # Load the model to get class names
299
  class_names = list(model.names.values()) # Get class names
300
  selected_classes = gr.CheckboxGroup(choices=class_names, label="Select Classes to Detect")
301
 
 
307
  gr.Markdown("### Step 4: Set Target FPS (Optional)")
308
  target_fps = gr.Slider(minimum=1, maximum=120*4, value=60, label="Target FPS")
309
 
310
+ # Step 5: Select YOLO model
311
+ gr.Markdown("### Step 5: Select YOLO Model")
312
+ model_name = gr.Dropdown(choices=["yolov8n.pt", "yolov11n.pt","yolo12n.pt"], label="Select YOLO Model", value="yolov8n.pt")
313
+
314
  # Process the stream
315
  process_button = gr.Button("Process Stream")
316
 
 
321
  error_box = gr.Textbox(label="Errors/Warnings", interactive=False)
322
 
323
  # Event listener for processing the video
324
+ process_button.click(process_video, inputs=[confidence_threshold, selected_classes, stream_url, target_fps, model_name], outputs=[output_image, error_box])
325
 
326
  # Launch the interface
327
+ demo.launch(debug=True)