import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from typing import Dict, Iterator
import logging
import os
import shutil
import subprocess
import json
import tempfile
import time

logger = logging.getLogger(__name__)

def _grab_best_device(use_gpu: bool = True) -> str:
    """Return "cuda" when a GPU is available and requested, otherwise "cpu"."""
    if use_gpu and torch.cuda.is_available():
        return "cuda"
    return "cpu"
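
# A hedged variant kept for reference: it also probes Apple-silicon MPS. This is
# a sketch only; the flash-attention path used below is CUDA-specific, so it is
# illustrative rather than a drop-in replacement for _grab_best_device.
def _grab_best_device_with_mps(use_gpu: bool = True) -> str:
    if use_gpu and torch.cuda.is_available():
        return "cuda"
    if use_gpu and torch.backends.mps.is_available():
        return "mps"
    return "cpu"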

def get_video_duration_seconds(video_path: str) -> float:
    """Use ffprobe to get video duration in seconds."""
    cmd = [
        "ffprobe",
        "-v", "quiet",
        "-print_format", "json",
        "-show_format",
        video_path
    ]
    # check=True surfaces ffprobe failures instead of handing empty JSON to json.loads
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    info = json.loads(result.stdout)
    return float(info["format"]["duration"])
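
# A hedged alternative sketch: ffprobe can print just the duration value, which
# avoids JSON parsing. The flags are standard ffprobe options (-show_entries
# selects the field, -of csv=p=0 prints the bare value).
def get_video_duration_seconds_plain(video_path: str) -> float:
    cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "csv=p=0",
        video_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    return float(result.stdout.strip())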

def format_duration(seconds: int) -> str:
    """Format seconds as MM:SS (minutes are not wrapped at the hour)."""
    minutes = seconds // 60
    secs = seconds % 60
    return f"{minutes:02d}:{secs:02d}"

DEVICE = _grab_best_device()

logger.info(f"Using device: {DEVICE}")

class VideoAnalyzer:
    def __init__(self):
        # The flash-attention code path below is CUDA-only, so fail fast here
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is required but not available!")
            
        logger.info("Initializing VideoAnalyzer")
        self.model_path = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
        logger.info(f"Loading model from {self.model_path} - Using device: {DEVICE}")
        
        # Load processor and model
        self.processor = AutoProcessor.from_pretrained(self.model_path)

        self.model = AutoModelForImageTextToText.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            device_map=DEVICE,
            # flash_attention_2 needs the optional flash-attn package and a
            # supported NVIDIA GPU; "sdpa" is a safe fallback if it is missing
            attn_implementation="flash_attention_2",
            low_cpu_mem_usage=True,
        )
        
        # Compile for faster steady-state inference; the first few generate()
        # calls are slower while torch.compile traces and optimizes the graph
        self.model = torch.compile(self.model, mode="reduce-overhead")
        logger.info(f"Model loaded and compiled on device: {self.model.device}")
        
    def analyze_segment(self, video_path: str, start_time: float) -> str:
        """Analyze a single video segment.

        Note: start_time is not fed into the prompt yet; it is accepted so the
        caller can thread segment timing through without an interface change.
        """
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": """You are a detailed video analysis assistant. Analyze and describe:
1. People: their appearance, actions, and interactions
2. Environment: location, weather, time of day, lighting
3. Objects: key items, their positions and movements
4. Text: any visible text, signs, or captions
5. Events: what is happening in sequence
6. Visual details: colors, patterns, visual effects
Be specific about timing and details to enable searching through the video later."""}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": """Describe this segment comprehensively. Include:
- Who appears and what are they doing?
- What is the environment and weather like?
- What objects or items are visible?
- Is there any text visible on screen?
- What actions or events are occurring?
- Note any significant visual details
Be specific about all visual elements to enable searching later."""}
                ]
            }
        ]
        
        # BatchFeature.to() moves every tensor to the device but casts only the
        # floating-point ones (pixel values) to bfloat16; token ids stay integral
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(DEVICE, dtype=torch.bfloat16)
        
        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                do_sample=True,
                temperature=0.7,
                max_new_tokens=256,
            )
        # The decoded string contains the full chat transcript; keep only the
        # text after the final "Assistant: " marker (brittle but simple)
        return self.processor.batch_decode(outputs, skip_special_tokens=True)[0].split("Assistant: ")[-1]

    def process_video(self, video_path: str, segment_length: int = 10) -> Iterator[Dict]:
        # Create temp directory for segments; removed in the finally block so
        # nothing is leaked if the generator raises or is closed early
        temp_dir = tempfile.mkdtemp()
        try:
            # Get video duration
            duration = get_video_duration_seconds(video_path)
            total_segments = (int(duration) + segment_length - 1) // segment_length
            logger.info(f"Processing {total_segments} segments for video of length {duration:.2f} seconds")
            
            # Process video in segments
            for segment_idx in range(total_segments):
                segment_start_time = time.time()
                start_time = segment_idx * segment_length
                end_time = min(start_time + segment_length, duration)
                
                # Defensive guard; the loop bound already prevents this case
                if start_time >= duration:
                    break
                
                # Create segment - optimized ffmpeg settings: -ss before -i uses
                # fast input seeking, which stays frame-accurate when re-encoding
                segment_path = os.path.join(temp_dir, f"segment_{start_time}.mp4")
                cmd = [
                    "ffmpeg",
                    "-y",
                    "-ss", str(start_time),
                    "-i", video_path,
                    "-t", str(segment_length),
                    "-c:v", "libx264",
                    "-preset", "ultrafast",  # Use ultrafast preset for speed
                    "-pix_fmt", "yuv420p",   # Ensure compatible pixel format
                    segment_path
                ]

                ffmpeg_start = time.time()
                # capture_output keeps ffmpeg's progress output off the console;
                # check=True raises CalledProcessError if the cut fails
                subprocess.run(cmd, check=True, capture_output=True)
                ffmpeg_time = time.time() - ffmpeg_start
                
                # Analyze segment
                inference_start = time.time()
                description = self.analyze_segment(segment_path, start_time)
                inference_time = time.time() - inference_start
                
                # Add segment info with timestamp
                yield {
                    "timestamp": format_duration(int(start_time)),
                    "description": description,
                    "processing_times": {
                        "ffmpeg": ffmpeg_time,
                        "inference": inference_time,
                        "total": time.time() - segment_start_time
                    }
                }
                
                # Clean up segment file
                os.remove(segment_path)
                
                logger.info(
                    f"Segment {segment_idx + 1}/{total_segments} ({start_time}-{end_time}s) - "
                    f"FFmpeg: {ffmpeg_time:.2f}s, Inference: {inference_time:.2f}s"
                )
            
        except Exception as e:
            logger.error(f"Error processing video: {str(e)}", exc_info=True)
            raise
        finally:
            # Remove the temp directory and any remaining segment files, even
            # when the generator exits early or raises
            shutil.rmtree(temp_dir, ignore_errors=True)
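

if __name__ == "__main__":
    # Minimal usage sketch: "sample.mp4" is a placeholder path, and running this
    # requires a CUDA GPU plus flash-attn per VideoAnalyzer's requirements.
    logging.basicConfig(level=logging.INFO)
    analyzer = VideoAnalyzer()
    for segment in analyzer.process_video("sample.mp4", segment_length=10):
        print(f"[{segment['timestamp']}] {segment['description']}")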