import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple

import cv2
import faiss
import gradio as gr
import numpy as np
import torch
from moviepy.video.io.VideoFileClip import VideoFileClip
from PIL import Image
from tqdm import tqdm
from transformers import Blip2ForConditionalGeneration, Blip2Processor, CLIPModel, CLIPProcessor

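# Pipeline overview: key frames are sampled from the uploaded video, captioned with
# BLIP-2, embedded with CLIP, and indexed with FAISS (exact L2 search); a text query
# is then embedded with CLIP and matched against that index, and Gradio provides the UI.
# Assumed dependencies (the usual pip names for the imports above): opencv-python,
# numpy, torch, transformers, pillow, faiss-cpu (or faiss-gpu), gradio, tqdm, moviepy.
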
class VideoRAGSystem:
    """Retrieval over video key frames: CLIP embeddings for search, BLIP-2 for captions."""

    def __init__(self):
        self.logger = self.setup_logger()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"Using device: {self.device}")

        # CLIP is kept in float32 so it behaves identically on CPU and GPU and its
        # embeddings can be added to FAISS (which expects float32) without casting.
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # BLIP-2 is loaded in float16 on GPU to reduce its memory footprint.
        self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(self.device)

        # Search index and per-frame metadata, populated by process_video().
        self.frame_index = None
        self.frame_data = []
        self.target_size = (224, 224)

        # Working directory for extracted frames and clips.
        self.temp_dir = tempfile.mkdtemp()
        self.frames_dir = os.path.join(self.temp_dir, "frames")
        os.makedirs(self.frames_dir, exist_ok=True)

    def setup_logger(self) -> logging.Logger:
        logger = logging.getLogger('VideoRAGSystem')
        if logger.handlers:
            logger.handlers.clear()
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        return logger

    def split_video(self, video_path: str, timestamp_ms: int, context_seconds: int = 3) -> str:
        """Extract a clip around the specified timestamp."""
        timestamp_sec = timestamp_ms / 1000
        output_path = os.path.join(self.temp_dir, "clip.mp4")

        with VideoFileClip(video_path) as video:
            duration = video.duration
            start_time = max(timestamp_sec - context_seconds, 0)
            end_time = min(timestamp_sec + context_seconds, duration)
            # Note: subclip() is the MoviePy 1.x API; MoviePy 2.x renamed it to subclipped().
            clip = video.subclip(start_time, end_time)
            clip.write_videofile(output_path, audio_codec='aac')

        return output_path

    @torch.no_grad()
    def analyze_frame(self, image: Image.Image) -> Dict:
        """Caption a frame with BLIP-2 and embed it with CLIP."""
        try:
            # BLIP-2 caption. The model runs in float16 on GPU, so its floating-point
            # inputs must be downcast to match.
            inputs = self.blip_processor(image, return_tensors="pt").to(self.device)
            if self.device.type == "cuda":
                inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
            caption = self.blip_model.generate(**inputs, max_length=50)
            caption_text = self.blip_processor.decode(caption[0], skip_special_tokens=True)

            # CLIP image embedding. The CLIP model stays in float32, so its inputs are
            # left as-is (downcasting them would cause a dtype mismatch with the weights).
            clip_inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
            features = self.clip_model.get_image_features(**clip_inputs)

            return {
                "caption": caption_text,
                # float32 NumPy array, ready to be added to the FAISS index.
                "features": features.cpu().numpy().astype(np.float32)
            }
        except Exception as e:
            self.logger.error(f"Frame analysis error: {str(e)}")
            return None

    def extract_keyframes(self, video_path: str, max_frames: int = 15) -> List[Dict]:
        """Sample frames at a fixed interval and analyze each one."""
        cap = cv2.VideoCapture(video_path)
        frames_info = []
        frame_count = 0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back if FPS metadata is missing
        interval = max(1, total_frames // max_frames)

        with tqdm(total=max_frames, desc="Analyzing frames") as pbar:
            while len(frames_info) < max_frames and cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % interval == 0:
                    # Keep the full-resolution frame on disk for display in the UI.
                    frame_path = os.path.join(self.frames_dir, f"frame_{frame_count}.jpg")
                    cv2.imwrite(frame_path, frame)

                    # OpenCV delivers BGR; convert and resize for the vision models.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    image = Image.fromarray(frame_rgb).resize(self.target_size, Image.LANCZOS)
                    analysis = self.analyze_frame(image)

                    if analysis is not None:
                        frames_info.append({
                            "frame_number": frame_count,
                            "timestamp": frame_count / fps,
                            "path": frame_path,
                            "caption": analysis["caption"],
                            "features": analysis["features"]
                        })
                        pbar.update(1)

                frame_count += 1

        cap.release()
        return frames_info

    def process_video(self, video_path: str):
        """Process a video and build the frame search index."""
        self.logger.info(f"Processing video: {video_path}")

        try:
            frames_info = self.extract_keyframes(video_path)
            self.frame_data = frames_info

            if not frames_info:
                self.logger.warning("No frames could be analyzed")
                return False

            # Stack the per-frame CLIP embeddings and index them for exact L2 search.
            features = np.vstack([frame["features"] for frame in frames_info])
            self.frame_index = faiss.IndexFlatL2(features.shape[1])
            self.frame_index.add(features)

            self.logger.info(f"Processed {len(frames_info)} frames successfully")
            return True

        except Exception as e:
            self.logger.error(f"Video processing error: {str(e)}")
            return False

    @torch.no_grad()
    def search_frames(self, query: str, k: int = 4) -> List[Dict]:
        """Return the k frames whose CLIP embeddings are closest to the query."""
        if self.frame_index is None:
            self.logger.error("No index available; call process_video() first")
            return []

        try:
            # Embed the query text with CLIP (token ids are integers, so no dtype handling needed).
            inputs = self.clip_processor(text=[query], return_tensors="pt").to(self.device)
            query_features = self.clip_model.get_text_features(**inputs)

            # Exact L2 search over the indexed frame embeddings.
            k = min(k, len(self.frame_data))
            distances, indices = self.frame_index.search(
                query_features.cpu().numpy().astype(np.float32),
                k
            )

            results = []
            for distance, idx in zip(distances[0], indices[0]):
                frame_info = self.frame_data[idx].copy()
                # Map L2 distance to a (0, 1] score; smaller distance means higher relevance.
                frame_info["relevance"] = float(1 / (1 + distance))
                results.append(frame_info)

            return results

        except Exception as e:
            self.logger.error(f"Search error: {str(e)}")
            return []

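# Minimal sketch of using the retrieval layer directly, without the Gradio UI
# ("sample.mp4" and the query string below are placeholders):
#
#   rag = VideoRAGSystem()
#   if rag.process_video("sample.mp4"):
#       for hit in rag.search_frames("a person walking through a door", k=4):
#           print(f"{hit['timestamp']:.1f}s  {hit['caption']}  (relevance {hit['relevance']:.2f})")
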
class VideoQAInterface:
    """Gradio front end: upload a video, then ask natural-language questions about it."""

    def __init__(self):
        self.rag_system = VideoRAGSystem()
        self.current_video = None
        self.processed = False

    def process_video(self, video_file):
        """Handle video upload and processing."""
        try:
            if video_file is None:
                return "Please upload a video first."

            # gr.File may hand back a path string or a tempfile-like object,
            # depending on the Gradio version.
            self.current_video = video_file if isinstance(video_file, str) else video_file.name
            success = self.rag_system.process_video(self.current_video)

            if success:
                self.processed = True
                return "Video processed successfully! You can now ask questions."
            else:
                return "Error processing video. Please try again."

        except Exception as e:
            self.processed = False
            return f"Error: {str(e)}"

    def answer_question(self, query):
        """Handle question answering."""
        if not self.processed:
            return None, "Please process a video first."

        try:
            results = self.rag_system.search_frames(query)

            if not results:
                return None, "No relevant frames found."

            frames = []
            descriptions = []

            for result in results:
                frame = Image.open(result["path"])
                frames.append(frame)

                desc = f"Timestamp: {result['timestamp']:.2f}s\n"
                desc += f"Scene Description: {result['caption']}\n"
                desc += f"Relevance Score: {result['relevance']:.2f}"
                descriptions.append(desc)

            combined_desc = "Frame Analysis:\n\n"
            for i, desc in enumerate(descriptions, 1):
                combined_desc += f"Frame {i}:\n{desc}\n\n"

            return frames, combined_desc

        except Exception as e:
            return None, f"Error: {str(e)}"

    def create_interface(self):
        """Create the Gradio interface."""
        with gr.Blocks(title="Advanced Video Question Answering") as interface:
            gr.Markdown("# Advanced Video Question Answering")
            gr.Markdown("Upload a video and ask questions about any aspect of its content!")

            with gr.Row():
                video_input = gr.File(
                    label="Upload Video",
                    file_types=["video"],
                )
                process_button = gr.Button("Process Video")

            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )

            with gr.Row():
                query_input = gr.Textbox(
                    label="Ask about the video",
                    placeholder="What's happening in the video?"
                )
                query_button = gr.Button("Search")

            gallery = gr.Gallery(
                label="Retrieved Frames",
                show_label=True,
                columns=[2],
                rows=[2],
                height="auto"
            )

            descriptions = gr.Textbox(
                label="Scene Analysis",
                interactive=False,
                lines=10
            )

            process_button.click(
                fn=self.process_video,
                inputs=[video_input],
                outputs=[status_output]
            )

            query_button.click(
                fn=self.answer_question,
                inputs=[query_input],
                outputs=[gallery, descriptions]
            )

        return interface

app = VideoQAInterface()
interface = app.create_interface()

if __name__ == "__main__":
    interface.launch()
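# For a temporary public link (e.g. when running in a notebook), Gradio's launch()
# also accepts share=True:
#   interface.launch(share=True)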