# capradeepgujaran's picture
# Update app.py
# 007d795 verified
# raw
# history blame
# 11.4 kB
import cv2
import numpy as np
from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration
import torch
from PIL import Image
import faiss
from typing import List, Dict, Tuple
import logging
import gradio as gr
import tempfile
import os
import shutil
from tqdm import tqdm
from pathlib import Path
from moviepy.video.io.VideoFileClip import VideoFileClip
class VideoRAGSystem:
    """Frame-level retrieval over one video using CLIP embeddings and BLIP-2 captions.

    A video is sampled into key frames; each frame is captioned with BLIP-2 and
    embedded with CLIP. The embeddings are stored in a FAISS flat L2 index so a
    free-text query (embedded with the CLIP text encoder) can retrieve the most
    similar frames.
    """

    def __init__(self):
        """Load both models and create a temporary working directory."""
        self.logger = self.setup_logger()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"Using device: {self.device}")

        # Initialize models. CLIP stays in its default dtype; BLIP-2 is loaded
        # in float16 on GPU and float32 on CPU.
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(self.device)

        # Vector store setup
        self.frame_index = None
        self.frame_data = []
        self.target_size = (224, 224)

        # Create directories for storing processed data
        self.temp_dir = tempfile.mkdtemp()
        self.frames_dir = os.path.join(self.temp_dir, "frames")
        os.makedirs(self.frames_dir, exist_ok=True)

    def setup_logger(self) -> logging.Logger:
        """Return the 'VideoRAGSystem' logger with exactly one stream handler.

        Existing handlers are cleared first so that creating several instances
        does not duplicate every log line.
        """
        logger = logging.getLogger('VideoRAGSystem')
        if logger.handlers:
            logger.handlers.clear()
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        return logger

    @staticmethod
    def _cast_floats(inputs, dtype) -> Dict:
        """Return ``inputs`` as a plain dict with floating-point tensors cast to ``dtype``.

        Integer tensors (token ids, attention masks) are passed through untouched.
        """
        return {
            name: tensor.to(dtype) if torch.is_floating_point(tensor) else tensor
            for name, tensor in inputs.items()
        }

    @staticmethod
    def _l2_normalize(vectors: np.ndarray) -> np.ndarray:
        """Return ``vectors`` as float32 rows scaled to unit L2 norm.

        All-zero rows are left as zeros instead of producing NaNs.
        """
        vectors = np.ascontiguousarray(vectors, dtype=np.float32)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        return vectors / np.maximum(norms, 1e-12)

    def split_video(self, video_path: str, timestamp_ms: int, context_seconds: int = 3) -> str:
        """Extract a clip around the specified timestamp.

        Args:
            video_path: Path of the source video.
            timestamp_ms: Centre of the clip, in milliseconds.
            context_seconds: Seconds to keep on each side of the timestamp.

        Returns:
            Path of the written clip (``clip.mp4`` inside the temp directory;
            each call overwrites the previous clip).
        """
        timestamp_sec = timestamp_ms / 1000
        output_path = os.path.join(self.temp_dir, "clip.mp4")
        with VideoFileClip(video_path) as video:
            duration = video.duration
            # Clamp the window to the bounds of the video.
            start_time = max(timestamp_sec - context_seconds, 0)
            end_time = min(timestamp_sec + context_seconds, duration)
            clip = video.subclip(start_time, end_time)
            clip.write_videofile(output_path, audio_codec='aac')
        return output_path

    @torch.no_grad()
    def analyze_frame(self, image: "Image.Image") -> Dict:
        """Caption a frame with BLIP-2 and embed it with CLIP.

        Args:
            image: The frame as a PIL image.

        Returns:
            ``{"caption": str, "features": float32 ndarray}``, or ``None`` when
            analysis fails (the error is logged, not raised).
        """
        try:
            # Generate caption. Inputs must match the dtype BLIP-2 was loaded in.
            blip_inputs = self._cast_floats(
                self.blip_processor(image, return_tensors="pt").to(self.device),
                self.blip_model.dtype,
            )
            token_ids = self.blip_model.generate(**blip_inputs, max_length=50)
            caption_text = self.blip_processor.decode(token_ids[0], skip_special_tokens=True).strip()

            # Get visual features. CLIP is NOT loaded in half precision, so the
            # inputs follow the model's own dtype rather than being forced to fp16.
            clip_inputs = self._cast_floats(
                self.clip_processor(images=image, return_tensors="pt").to(self.device),
                self.clip_model.dtype,
            )
            features = self.clip_model.get_image_features(**clip_inputs)
            return {
                "caption": caption_text,
                # FAISS only accepts float32 vectors.
                "features": features.float().cpu().numpy(),
            }
        except Exception as e:
            self.logger.error(f"Frame analysis error: {str(e)}")
            return None

    def _build_frame_record(self, frame, frame_number: int, fps: float):
        """Save one BGR frame to disk, analyze it, and return its metadata dict (or None)."""
        frame_path = os.path.join(self.frames_dir, f"frame_{frame_number}.jpg")
        if not cv2.imwrite(frame_path, frame):
            # Without the file on disk the frame could not be displayed later.
            self.logger.warning(f"Could not write frame to {frame_path}")
            return None

        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(frame_rgb).resize(self.target_size, Image.LANCZOS)
        analysis = self.analyze_frame(image)
        if analysis is None:
            return None
        return {
            "frame_number": frame_number,
            # Some containers report no usable FPS; fall back to 0.0 instead of dividing by zero.
            "timestamp": frame_number / fps if fps > 0 else 0.0,
            "path": frame_path,
            "caption": analysis["caption"],
            "features": analysis["features"],
        }

    def extract_keyframes(self, video_path: str, max_frames: int = 15) -> List[Dict]:
        """Extract and analyze key frames sampled evenly across the video.

        Args:
            video_path: Path of the video to sample.
            max_frames: Upper bound on the number of analyzed frames returned.

        Returns:
            One dict per successfully analyzed frame with the keys
            ``frame_number``, ``timestamp`` (seconds), ``path``, ``caption``
            and ``features``. Empty when the video cannot be read.
        """
        frames_info = []
        if max_frames <= 0:
            return frames_info

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                self.logger.error(f"Could not open video: {video_path}")
                return frames_info

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            interval = max(1, total_frames // max_frames)
            frame_count = 0
            with tqdm(total=max_frames, desc="Analyzing frames") as pbar:
                while len(frames_info) < max_frames:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    if frame_count % interval == 0:
                        record = self._build_frame_record(frame, frame_count, fps)
                        if record is not None:
                            frames_info.append(record)
                            pbar.update(1)
                    frame_count += 1
        finally:
            # Release the capture even if saving or analysis raised.
            cap.release()
        return frames_info

    def process_video(self, video_path: str):
        """Process video and build search index.

        Any index from a previous video is discarded first, so a failed run can
        never serve stale frames.

        Returns:
            True when at least one frame was indexed, False otherwise.
        """
        self.logger.info(f"Processing video: {video_path}")
        self.frame_index = None
        self.frame_data = []
        try:
            # Extract and analyze frames
            frames_info = self.extract_keyframes(video_path)
            if not frames_info:
                self.logger.warning(f"No frames could be extracted from video: {video_path}")
                return False

            # Build FAISS index over unit-length embeddings so that L2 distance
            # ranks frames by cosine similarity.
            features = self._l2_normalize(
                np.vstack([frame["features"] for frame in frames_info])
            )
            index = faiss.IndexFlatL2(features.shape[1])
            index.add(features)

            # Publish index and metadata together, only once both are ready.
            self.frame_index = index
            self.frame_data = frames_info
            self.logger.info(f"Processed {len(frames_info)} frames successfully")
            return True
        except Exception as e:
            self.logger.error(f"Video processing error: {str(e)}")
            return False

    @torch.no_grad()
    def search_frames(self, query: str, k: int = 4) -> List[Dict]:
        """Search for relevant frames based on the query.

        Args:
            query: Free-text description of what to look for.
            k: Maximum number of frames to return.

        Returns:
            Copies of the stored frame dicts, each with an added ``relevance``
            key equal to ``1 / (1 + distance)``. Empty when no video has been
            indexed or the search fails (the error is logged).
        """
        if self.frame_index is None or not self.frame_data:
            return []
        # Asking FAISS for more neighbours than exist yields -1 labels.
        top_k = min(k, len(self.frame_data))
        if top_k <= 0:
            return []

        try:
            # Process query; truncate so over-long text cannot overflow CLIP's context.
            inputs = self.clip_processor(
                text=[query], return_tensors="pt", padding=True, truncation=True
            ).to(self.device)
            inputs = self._cast_floats(inputs, self.clip_model.dtype)
            query_features = self.clip_model.get_text_features(**inputs)
            query_vector = self._l2_normalize(query_features.float().cpu().numpy())

            # Search
            distances, indices = self.frame_index.search(query_vector, top_k)

            # Prepare results
            results = []
            for distance, idx in zip(distances[0], indices[0]):
                if idx < 0:
                    # -1 marks "no neighbour"; it must not index frame_data from the end.
                    continue
                frame_info = self.frame_data[idx].copy()
                frame_info["relevance"] = float(1 / (1 + distance))
                results.append(frame_info)
            return results
        except Exception as e:
            self.logger.error(f"Search error: {str(e)}")
            return []
class VideoQAInterface:
    """Gradio front end connecting video upload and text queries to a VideoRAGSystem."""

    def __init__(self):
        """Create the retrieval backend and reset the upload state."""
        self.rag_system = VideoRAGSystem()
        self.current_video = None
        self.processed = False

    def process_video(self, video_file):
        """Handle video upload and processing.

        Args:
            video_file: Value delivered by the upload component. Either a
                filesystem path or an object exposing the path as ``.name``.

        Returns:
            A single status message for the status textbox.
        """
        if video_file is None:
            return "Please upload a video first."

        try:
            # The upload may arrive as a plain path or as a file wrapper,
            # so accept both instead of assuming ``.name`` exists.
            if isinstance(video_file, (str, os.PathLike)):
                self.current_video = os.fspath(video_file)
            else:
                self.current_video = video_file.name
            # Track the outcome of THIS video so an earlier success cannot
            # leave questions enabled after a failed run.
            self.processed = bool(self.rag_system.process_video(self.current_video))
        except Exception as e:
            self.processed = False
            return f"Error: {str(e)}"

        if self.processed:
            return "Video processed successfully! You can now ask questions."
        return "Error processing video. Please try again."

    def answer_question(self, query):
        """Handle question answering.

        Args:
            query: The user's free-text question.

        Returns:
            ``(frames, text)`` where ``frames`` is a list of PIL images (or
            ``None`` when there is nothing to show) and ``text`` is either the
            per-frame analysis or a message explaining why no frames are shown.
        """
        if not self.processed:
            return None, "Please process a video first."
        if not query or not query.strip():
            return None, "Please enter a question."

        try:
            # Search for relevant frames
            results = self.rag_system.search_frames(query)
            if not results:
                return None, "No relevant frames found."

            # Prepare output
            frames = []
            descriptions = []
            for result in results:
                # Copy the pixels out so the file handle is closed right away.
                with Image.open(result["path"]) as frame:
                    frames.append(frame.copy())
                descriptions.append(
                    f"Timestamp: {result['timestamp']:.2f}s\n"
                    f"Scene Description: {result['caption']}\n"
                    f"Relevance Score: {result['relevance']:.2f}"
                )

            # Combine descriptions
            combined_desc = "\n\nFrame Analysis:\n\n" + "".join(
                f"Frame {i}:\n{desc}\n\n" for i, desc in enumerate(descriptions, 1)
            )
            return frames, combined_desc
        except Exception as e:
            return None, f"Error: {str(e)}"

    def create_interface(self):
        """Create Gradio interface.

        Returns:
            The ``gr.Blocks`` app with the upload/process controls, the query
            box, the frame gallery and the analysis textbox wired up.
        """
        with gr.Blocks(title="Advanced Video Question Answering") as interface:
            gr.Markdown("# Advanced Video Question Answering")
            gr.Markdown("Upload a video and ask questions about any aspect of its content!")

            with gr.Row():
                video_input = gr.File(
                    label="Upload Video",
                    file_types=["video"],
                )
                process_button = gr.Button("Process Video")

            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )

            with gr.Row():
                query_input = gr.Textbox(
                    label="Ask about the video",
                    placeholder="What's happening in the video?"
                )
                query_button = gr.Button("Search")

            gallery = gr.Gallery(
                label="Retrieved Frames",
                show_label=True,
                columns=[2],
                rows=[2],
                height="auto"
            )
            descriptions = gr.Textbox(
                label="Scene Analysis",
                interactive=False,
                lines=10
            )

            # One declared output, matching the single string process_video returns.
            process_button.click(
                fn=self.process_video,
                inputs=[video_input],
                outputs=[status_output]
            )
            query_button.click(
                fn=self.answer_question,
                inputs=[query_input],
                outputs=[gallery, descriptions]
            )
        return interface
# Build the application object and its Gradio UI at import time so that both
# `app` and `interface` are available as module-level names.
app = VideoQAInterface()
interface = app.create_interface()

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()