"""Streamlit app: per-frame YOLOv8 video detection with naive nearest-centre
tracking and a movement-direction label on each box."""

import os
import tempfile
import warnings

import cv2
import numpy as np
import streamlit as st
import torch
from ultralytics import YOLO

warnings.filterwarnings('ignore')

def get_direction(old_center, new_center, min_movement=10):
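    """Classify movement between two centres as left/right/up/down/stationary.

    Coordinates are in image space, so positive dy means the object moved
    down the frame; displacements under ``min_movement`` px count as noise.
    """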
    if old_center is None or new_center is None:
        return "stationary"
    
    dx = new_center[0] - old_center[0]
    dy = new_center[1] - old_center[1]
    
    if abs(dx) < min_movement and abs(dy) < min_movement:
        return "stationary"
    
    if abs(dx) > abs(dy):
        return "right" if dx > 0 else "left"
    else:
        return "down" if dy > 0 else "up"

class ObjectTracker:
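    """Minimal frame-to-frame tracker based on nearest-centre matching.

    Each detection is greedily matched to the closest previous centre of the
    same class within 100 px. This is a sketch, not a robust tracker: two
    detections can claim the same previous object, and IDs do not survive
    occlusion. Ultralytics' built-in ``model.track()`` (ByteTrack/BoT-SORT)
    is the production-grade alternative.
    """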
    def __init__(self):
        self.tracked_objects = {}  # object_id -> {'center', 'direction', 'detection'}
        self.object_count = {}     # class_id -> list of IDs ever assigned to that class
    
    def update(self, detections):
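        """Match ``detections`` ([x1, y1, x2, y2, conf, cls] lists) against the
        previous frame and return (detection, object_id, direction) tuples."""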
        current_objects = {}
        results = []
        
        for detection in detections:
            x1, y1, x2, y2 = detection[0:4]
            center = ((x1 + x2) // 2, (y1 + y2) // 2)
            class_id = detection[5]
            
            # Provisional ID for a brand-new object; replaced below if a
            # nearby same-class object existed in the previous frame.
            object_id = f"{class_id}_{len(self.object_count.get(class_id, []))}"
            
            # Greedy nearest-neighbour match against last frame's same-class
            # objects, gated at 100 px to limit wild re-assignments.
            min_dist = float('inf')
            closest_id = None
            
            for prev_id, prev_data in self.tracked_objects.items():
                if prev_id.split('_')[0] == str(class_id):
                    dist = np.sqrt((center[0] - prev_data['center'][0])**2 + 
                                 (center[1] - prev_data['center'][1])**2)
                    if dist < min_dist and dist < 100:
                        min_dist = dist
                        closest_id = prev_id
            
            if closest_id:
                object_id = closest_id  # continue an existing track
            else:
                # Genuinely new object: register its ID under its class.
                if class_id not in self.object_count:
                    self.object_count[class_id] = []
                self.object_count[class_id].append(object_id)
            
            # Direction comes from the displacement since the previous frame.
            prev_center = self.tracked_objects.get(object_id, {}).get('center', None)
            direction = get_direction(prev_center, center)
            
            current_objects[object_id] = {
                'center': center,
                'direction': direction,
                'detection': detection
            }
            
            results.append((detection, object_id, direction))
        
        # Objects unseen this frame are dropped, so one missed detection resets identity.
        self.tracked_objects = current_objects
        return results


def main():
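    """Streamlit page: upload a video, run YOLOv8 per frame, and stream back
    frames annotated with direction-coloured boxes and per-class counts."""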
    st.title("Real-time Object Detection with Direction")
    
    uploaded_file = st.file_uploader("Choose a video file", type=['mp4', 'avi', 'mov'])
    
    start_detection = st.button("Start Detection")
    stop_detection = st.button("Stop Detection")

    # A Streamlit button is only True on the rerun its click triggers, and a
    # click also interrupts any script still running, so record the stop
    # request in session_state to keep the loop below disarmed.
    if stop_detection:
        st.session_state.running = False
    
    if uploaded_file is not None and start_detection:
        st.session_state.running = True  # (re)arm on every Start click, not just the first

        # cv2.VideoCapture needs a real file path, so persist the upload to a
        # temp file and close it first (Windows cannot reopen an open file).
        suffix = os.path.splitext(uploaded_file.name)[1]
        tfile = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        tfile.write(uploaded_file.read())
        tfile.close()
        
        with st.spinner('Loading model...'):
            # yolov8x is the largest/most accurate variant; 'yolov8n.pt'
            # is a much faster swap-in for CPU-only machines.
            model = YOLO('yolov8x.pt', verbose=False)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model.to(device)
        
        tracker = ObjectTracker()
        cap = cv2.VideoCapture(tfile.name)
        
        # OpenCV draws in BGR order, so these tuples are (B, G, R).
        direction_colors = {
            "left": (255, 0, 0),           # blue
            "right": (0, 255, 0),          # green
            "up": (0, 255, 255),           # yellow
            "down": (0, 0, 255),           # red
            "stationary": (128, 128, 128)  # grey
        }
        
        frame_placeholder = st.empty()
        info_placeholder = st.empty()
        
        st.success("Detection Started!")
        
        while cap.isOpened() and st.session_state.running:
            success, frame = cap.read()
            if not success:
                break
            
            # conf/iou are the confidence and NMS IoU thresholds; max_det
            # caps detections per frame to keep the tracker cheap.
            results = model(frame,
                            conf=0.25,
                            iou=0.45,
                            max_det=20,
                            verbose=False)[0]
            
            # Flatten YOLO's tensor output to [x1, y1, x2, y2, conf, cls].
            detections = []
            for box in results.boxes.data:
                x1, y1, x2, y2, conf, cls = box.tolist()
                detections.append([int(x1), int(y1), int(x2), int(y2), float(conf), int(cls)])
            
            tracked_objects = tracker.update(detections)
            
            detection_counts = {}
            
            for detection, obj_id, direction in tracked_objects:
                x1, y1, x2, y2, conf, cls = detection
                color = direction_colors.get(direction, (128, 128, 128))
                
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
                
                label = f"{model.names[int(cls)]} {direction} {conf:.2f}"
                font_scale = 1.2
                thickness = 3
                text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)[0]
                
                padding_y = 15
                label_top = max(int(y1) - text_size[1] - padding_y, 0)  # clamp so the label stays on-frame
                cv2.rectangle(frame,
                              (int(x1), label_top),
                              (int(x1) + text_size[0], int(y1)),
                              color, -1)
                
                cv2.putText(frame, label, 
                           (int(x1), int(y1) - 5), 
                           cv2.FONT_HERSHEY_SIMPLEX, 
                           font_scale, 
                           (255, 255, 255), 
                           thickness)
                
                class_name = model.names[int(cls)]
                detection_counts[class_name] = detection_counts.get(class_name, 0) + 1
            
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            
            # Update frame (use_container_width supersedes the deprecated use_column_width)
            frame_placeholder.image(frame_rgb, channels="RGB", use_container_width=True)
            
            # Update detection info
            info_text = "Detected Objects:\n"
            for class_name, count in detection_counts.items():
                info_text += f"{class_name}: {count}\n"
            info_placeholder.text(info_text)
            
            # No in-loop stop check is needed: clicking "Stop Detection"
            # triggers a rerun that interrupts this script (and the loop)
            # before the next iteration.
            
        cap.release()
        os.remove(tfile.name)  # delete the temp copy of the upload
        st.session_state.running = False
        st.warning("Detection Stopped")
        
    elif uploaded_file is None and start_detection:
        st.error("Please upload a video file first!")

if __name__ == "__main__":
    main()