import cv2
import os
import tempfile
import uuid
from PIL import Image
import numpy as np
from typing import Dict, List, Tuple, Any, Optional
import time
from collections import defaultdict

from image_processor import ImageProcessor
from evaluation_metrics import EvaluationMetrics
from scene_analyzer import SceneAnalyzer
from detection_model import DetectionModel


class VideoProcessor:
    """
    Handles the processing of video files, including object detection
    and scene analysis on selected frames.
    """

    def __init__(self, image_processor: ImageProcessor):
        """
        Initializes the VideoProcessor.

        Args:
            image_processor (ImageProcessor): An initialized ImageProcessor instance.
        """
        self.image_processor = image_processor

    def process_video_file(self,
                           video_path: str,
                           model_name: str,
                           confidence_threshold: float,
                           process_interval: int = 5,
                           scene_desc_interval_sec: int = 3) -> Tuple[Optional[str], str, Dict]:
        """
        Processes an uploaded video file, performs detection and periodic scene analysis,
        and returns the path to the annotated output video file along with a summary.

        Args:
            video_path (str): Path to the input video file.
            model_name (str): Name of the YOLO model to use.
            confidence_threshold (float): Confidence threshold for object detection.
            process_interval (int): Process every Nth frame. Defaults to 5.
            scene_desc_interval_sec (int): Update the scene description every N seconds. Defaults to 3.

        Returns:
            Tuple[Optional[str], str, Dict]: (Path to the output video or None, summary text, statistics dictionary)
        """
        if not video_path or not os.path.exists(video_path):
            print(f"Error: Video file not found at {video_path}")
            return None, "Error: Video file not found.", {}

        print(f"Starting video processing for: {video_path}")
        start_time = time.time()

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Error: Could not open video file {video_path}")
            return None, "Error opening video file.", {}

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:  # Handle case where FPS is not available or invalid
            fps = 30  # Assume a default FPS
            print(f"Warning: Could not get valid FPS for video. Assuming {fps} FPS.")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames_video = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Video properties: {width}x{height} @ {fps:.2f} FPS, Total Frames: {total_frames_video}")

        # Calculate the description update interval in frames
        description_update_interval_frames = int(fps * scene_desc_interval_sec)
        if description_update_interval_frames < 1:
            description_update_interval_frames = int(fps)  # Update at least once per second if the interval is too short

        # Simple IoU-based object tracking state
        object_trackers = {}        # Mapping of object IDs to tracked objects
        last_detected_objects = {}  # Objects detected in the previous processed frame
        next_object_id = 0          # Next available object ID
        tracking_threshold = 0.6    # IoU threshold for treating two boxes as the same object
        object_colors = {}          # Fixed color assigned to each tracked object

        # Setup output video
        output_filename = f"processed_{uuid.uuid4().hex}_{os.path.basename(video_path)}"
        temp_dir = tempfile.gettempdir()  # Use the system's temp directory
        output_path = os.path.join(temp_dir, output_filename)
        # Ensure the output path has a compatible extension (like .mp4)
        if not output_path.lower().endswith(('.mp4', '.avi', '.mov')):
            output_path += ".mp4"

        # Use 'mp4v' for MP4, common and well-supported
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            print(f"Error: Could not open VideoWriter for path: {output_path}")
            cap.release()
            return None, f"Error creating output video file at {output_path}.", {}
        print(f"Output video will be saved to: {output_path}")

        frame_count = 0
        processed_frame_count = 0
        all_stats = []  # Store stats for each processed frame
        summary_lines = []
        last_description = "Analyzing scene..."  # Initial description
        frame_since_last_desc = description_update_interval_frames  # Trigger analysis on the first processed frame

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break  # End of video

                frame_count += 1
                frame_since_last_desc += 1
                current_frame_annotated = False  # Flag set when this frame is processed and annotated

                # Process frame based on interval
                if frame_count % process_interval == 0:
                    processed_frame_count += 1
                    print(f"Processing frame {frame_count}...")
                    current_frame_annotated = True

                    # Use ImageProcessor for single-frame tasks
                    # 1. Convert frame format BGR -> RGB -> PIL
                    try:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        pil_image = Image.fromarray(frame_rgb)
                    except Exception as e:
                        print(f"Error converting frame {frame_count}: {e}")
                        continue  # Skip this frame

                    # 2. Get the appropriate model instance (model_name and confidence_threshold come from the UI)
                    model_instance = self.image_processor.get_model_instance(model_name, confidence_threshold)
                    if not model_instance or not model_instance.is_model_loaded:
                        print(f"Error: Model {model_name} not loaded. Skipping frame {frame_count}.")
                        # Draw basic frame without annotation
                        cv2.putText(frame, f"Scene: {last_description[:80]}...", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3, cv2.LINE_AA)
                        cv2.putText(frame, f"Scene: {last_description[:80]}...", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
                        out.write(frame)
                        continue

                    # 3. Perform detection
                    detection_result = model_instance.detect(pil_image)  # Use the PIL image
                    current_description_for_frame = last_description  # Default to the last known description
                    scene_analysis_result = None
                    stats = {}

                    if detection_result and hasattr(detection_result, 'boxes') and len(detection_result.boxes) > 0:
                        # Ensure SceneAnalyzer is ready within ImageProcessor
                        if not hasattr(self.image_processor, 'scene_analyzer') or self.image_processor.scene_analyzer is None:
                            print("Initializing SceneAnalyzer...")
                            # Pass class names from the current detection result
                            self.image_processor.scene_analyzer = SceneAnalyzer(class_names=detection_result.names)
                        elif self.image_processor.scene_analyzer.class_names is None:
                            # Update class names if they were missing
                            self.image_processor.scene_analyzer.class_names = detection_result.names
                            if hasattr(self.image_processor.scene_analyzer, 'spatial_analyzer'):
                                self.image_processor.scene_analyzer.spatial_analyzer.class_names = detection_result.names

                        # 4. Perform scene analysis (periodically)
                        if frame_since_last_desc >= description_update_interval_frames:
                            print(f"Analyzing scene at frame {frame_count} (threshold: {description_update_interval_frames} frames)...")
                            # Pass lighting_info=None for now, as it is disabled for performance
                            scene_analysis_result = self.image_processor.analyze_scene(detection_result, lighting_info=None)
                            current_description_for_frame = scene_analysis_result.get("description", last_description)
                            last_description = current_description_for_frame  # Cache the new description
                            frame_since_last_desc = 0  # Reset the counter

                        # 5. Calculate statistics for this frame
                        stats = EvaluationMetrics.calculate_basic_stats(detection_result)
                        stats['frame_number'] = frame_count  # Add the frame number to the stats
                        all_stats.append(stats)

                        # 6. Draw annotations
                        names = detection_result.names
                        boxes = detection_result.boxes.xyxy.cpu().numpy()
                        classes = detection_result.boxes.cls.cpu().numpy().astype(int)
                        confs = detection_result.boxes.conf.cpu().numpy()

                        def calculate_iou(box1, box2):
                            """Calculate the Intersection over Union (IoU) of two boxes in (x1, y1, x2, y2) format."""
                            x1_1, y1_1, x2_1, y2_1 = box1
                            x1_2, y1_2, x2_2, y2_2 = box2
                            xi1 = max(x1_1, x1_2)
                            yi1 = max(y1_1, y1_2)
                            xi2 = min(x2_1, x2_2)
                            yi2 = min(y2_1, y2_2)
                            inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
                            box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
                            box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
                            union_area = box1_area + box2_area - inter_area
                            return inter_area / union_area if union_area > 0 else 0
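
                        # Illustrative example of the IoU computation above (not part of the runtime logic):
                        # boxes (0, 0, 10, 10) and (5, 5, 15, 15) overlap in a 5x5 region, so
                        # inter_area = 25 and union_area = 100 + 100 - 25 = 175, giving an IoU of
                        # roughly 0.14 -- below the 0.6 tracking_threshold, so the two boxes would
                        # be treated as different objects.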

                        # Match every detection in the current frame to previously tracked objects
                        current_detected_objects = {}
                        for box, cls_id, conf in zip(boxes, classes, confs):
                            x1, y1, x2, y2 = map(int, box)

                            # Find the best-matching object that is already being tracked
                            best_match_id = None
                            best_match_iou = 0
                            for obj_id, (old_box, old_cls_id, _) in last_detected_objects.items():
                                if old_cls_id == cls_id:  # Only compare objects of the same class
                                    iou = calculate_iou(box, old_box)
                                    if iou > tracking_threshold and iou > best_match_iou:
                                        best_match_id = obj_id
                                        best_match_iou = iou

                            # Reuse the existing ID if a match was found; otherwise assign a new ID
                            if best_match_id is not None:
                                obj_id = best_match_id
                            else:
                                obj_id = next_object_id
                                next_object_id += 1

                            # Use high-visibility colors (BGR order)
                            bright_colors = [
                                (0, 0, 255),    # red
                                (0, 255, 0),    # green
                                (255, 0, 0),    # blue
                                (0, 255, 255),  # yellow
                                (255, 0, 255),  # magenta
                                (255, 128, 0),  # light blue
                                (128, 0, 255)   # pink
                            ]
                            object_colors[obj_id] = bright_colors[obj_id % len(bright_colors)]

                            # Update the tracking info for this frame
                            current_detected_objects[obj_id] = (box, cls_id, conf)
                            color = object_colors.get(obj_id, (0, 255, 0))  # Default is green
                            label = f"{names.get(cls_id, 'Unknown')}-{obj_id}: {conf:.2f}"

                            # Smooth the bounding box: if the object was seen before, average with its previous position
                            if obj_id in last_detected_objects:
                                old_box, _, _ = last_detected_objects[obj_id]
                                old_x1, old_y1, old_x2, old_y2 = map(int, old_box)
                                # Smoothing coefficients
                                alpha = 0.7  # Weight of the current position
                                beta = 0.3   # Weight of the previous position
                                x1 = int(alpha * x1 + beta * old_x1)
                                y1 = int(alpha * y1 + beta * old_y1)
                                x2 = int(alpha * x2 + beta * old_x2)
                                y2 = int(alpha * y2 + beta * old_y2)

                            # Draw box and label
                            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                            # Add label text on a filled background
                            (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
                            cv2.rectangle(frame, (x1, y1 - h - 10), (x1 + w, y1 - 10), color, -1)
                            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

                        # Carry the tracking info over to the next processed frame
                        last_detected_objects = current_detected_objects.copy()

                    # Draw the current scene description on the frame
                    cv2.putText(frame, f"Scene: {current_description_for_frame[:80]}...", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3, cv2.LINE_AA)      # Black outline
                    cv2.putText(frame, f"Scene: {current_description_for_frame[:80]}...", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)  # White text

                # Draw the last known description if this frame wasn't processed
                if not current_frame_annotated:
                    cv2.putText(frame, f"Scene: {last_description[:80]}...", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3, cv2.LINE_AA)
                    cv2.putText(frame, f"Scene: {last_description[:80]}...", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

                # Write the frame (annotated or original) to the output video
                out.write(frame)
        except Exception as e:
            print(f"Error during video processing loop for {video_path}: {e}")
            import traceback
            traceback.print_exc()
            summary_lines.append(f"An error occurred during processing: {e}")
        finally:
            # Release resources
            cap.release()
            out.release()
            print(f"Video processing finished. Resources released. Output path: {output_path}")

        if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
            print(f"Error: Output video file was not created or is empty at {output_path}")
            summary_lines.append("Error: Failed to create output video.")
            output_path = None

        end_time = time.time()
        processing_time = end_time - start_time
        summary_lines.insert(0, f"Finished processing in {processing_time:.2f} seconds.")
        summary_lines.insert(1, f"Processed {processed_frame_count} frames out of {frame_count} (interval: {process_interval} frames).")
        summary_lines.insert(2, f"Scene description updated approximately every {scene_desc_interval_sec} seconds.")

        # Generate aggregate statistics
        aggregated_stats = {
            "total_frames_read": frame_count,
            "total_frames_processed": processed_frame_count,
            "avg_objects_per_processed_frame": 0,  # Calculated below
            "cumulative_detections": {},           # Total number of times each class was detected
            "max_concurrent_detections": {}        # Max count of each class in a single processed frame
        }
        object_cumulative_counts = {}
        object_max_concurrent_counts = {}  # Max count found for each object class in any single frame
        total_detected_in_processed = 0

        # Iterate through the stats collected from each processed frame
        for frame_stats in all_stats:
            total_objects_in_frame = frame_stats.get("total_objects", 0)
            total_detected_in_processed += total_objects_in_frame
            # Iterate through the object classes detected in this frame
            for obj_name, obj_data in frame_stats.get("class_statistics", {}).items():
                count_in_frame = obj_data.get("count", 0)
                # Cumulative count
                if obj_name not in object_cumulative_counts:
                    object_cumulative_counts[obj_name] = 0
                object_cumulative_counts[obj_name] += count_in_frame
                # Max concurrent count
                if obj_name not in object_max_concurrent_counts:
                    object_max_concurrent_counts[obj_name] = 0
                # Update the max count if the current frame's count is higher
                object_max_concurrent_counts[obj_name] = max(object_max_concurrent_counts[obj_name], count_in_frame)

        # Add sorted results to the final dictionary
        aggregated_stats["cumulative_detections"] = dict(sorted(object_cumulative_counts.items(), key=lambda item: item[1], reverse=True))
        aggregated_stats["max_concurrent_detections"] = dict(sorted(object_max_concurrent_counts.items(), key=lambda item: item[1], reverse=True))

        # Calculate the average number of objects per processed frame
        if processed_frame_count > 0:
            aggregated_stats["avg_objects_per_processed_frame"] = round(total_detected_in_processed / processed_frame_count, 2)

        summary_text = "\n".join(summary_lines)
        print("Generated Summary:\n", summary_text)
        print("Aggregated Stats:\n", aggregated_stats)

        # Return the potentially updated output_path
        return output_path, summary_text, aggregated_stats
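

# Minimal usage sketch (not part of the original module): it assumes ImageProcessor can be
# constructed without arguments and that "yolov8n.pt" is an available model name; the sample
# video path is a placeholder. Adjust these to match the actual application setup.
if __name__ == "__main__":
    image_processor = ImageProcessor()  # Assumed no-arg constructor
    video_processor = VideoProcessor(image_processor)
    output_video, summary, stats = video_processor.process_video_file(
        video_path="sample.mp4",        # Placeholder input path
        model_name="yolov8n.pt",        # Assumed model identifier
        confidence_threshold=0.25,
        process_interval=5,             # Run detection on every 5th frame
        scene_desc_interval_sec=3,      # Refresh the scene description roughly every 3 seconds
    )
    print(summary)
    print(stats)
    print("Annotated video:", output_video)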