|
import gradio as gr |
|
from video_rag_tool import VideoRAGTool |
|
import tempfile |
|
import os |
|
from PIL import Image |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
|
|
class VideoRAGApp:
    """Gradio front end for indexing an uploaded video and searching its frames.

    The retrieval work is delegated to ``VideoRAGTool``; this class only
    stores the upload, tracks whether a video has been indexed, and turns
    query results into images and captions for the UI.
    """

    def __init__(self):
        self.rag_tool = VideoRAGTool()
        # Path of the private copy of the video that was last indexed.
        self.current_video_path = None
        # True only while ``rag_tool`` holds an index for ``current_video_path``.
        self.processed = False
        # Temporary directory owning ``current_video_path``; removed when replaced.
        self._temp_dir = None

    def process_video(self, video_file):
        """Store the uploaded video in a temp directory and index it.

        Args:
            video_file: The upload as delivered by ``gr.File``. Raw bytes, a
                filesystem path, or a temp-file object exposing ``.name`` are
                all accepted, because the payload form depends on the
                component's ``type`` setting and the Gradio version.

        Returns:
            A status message for the UI. Errors are reported in the message
            rather than raised.
        """
        # Function-scope import keeps this block self-contained.
        import shutil

        if video_file is None:
            return "Please upload a video first."

        # Any previous index is considered stale until the new video has been
        # processed successfully, so queries never mix two different videos.
        self.processed = False

        temp_dir = tempfile.mkdtemp()
        temp_path = os.path.join(temp_dir, "uploaded_video.mp4")

        try:
            if isinstance(video_file, (bytes, bytearray, memoryview)):
                with open(temp_path, "wb") as f:
                    f.write(video_file)
            else:
                if isinstance(video_file, (str, os.PathLike)):
                    source_path = video_file
                else:
                    # Older Gradio versions hand over a temp-file wrapper.
                    source_path = getattr(video_file, "name", None)
                if source_path is None:
                    raise TypeError(
                        f"Unsupported upload type: {type(video_file).__name__}"
                    )
                shutil.copyfile(source_path, temp_path)

            self.rag_tool.process_video(temp_path)
        except Exception as e:
            # Do not leave a copy of a video that could not be processed.
            shutil.rmtree(temp_dir, ignore_errors=True)
            return f"Error processing video: {str(e)}"

        # Switch over to the new video only after everything succeeded, then
        # drop the directory of the video it replaces.
        previous_dir = self._temp_dir
        self.current_video_path = temp_path
        self._temp_dir = temp_dir
        self.processed = True
        if previous_dir and previous_dir != temp_dir:
            shutil.rmtree(previous_dir, ignore_errors=True)

        return "Video processed successfully! You can now ask questions about the video."

    def query_video(self, query_text):
        """Search the processed video and return matching frames with details.

        Args:
            query_text: Free-text question typed by the user.

        Returns:
            A ``(frames, details)`` pair matching the two UI outputs:
            ``frames`` is a list of PIL images and ``details`` is a single
            string with one caption paragraph per frame. On any problem
            ``frames`` is empty and ``details`` carries the message.
        """
        if not self.processed:
            return [], "Please process a video first."

        if not query_text or not query_text.strip():
            return [], "Please enter a question about the video."

        try:
            # Each result is read below via the keys 'frame_number',
            # 'timestamp' and 'relevance_score'.
            results = self.rag_tool.query_video(query_text, k=4)

            frames = []
            captions = []

            cap = cv2.VideoCapture(self.current_video_path)
            try:
                for result in results:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, result['frame_number'])
                    ret, frame = cap.read()
                    if not ret:
                        # Skip unreadable frames entirely so that images and
                        # captions stay aligned one-to-one.
                        continue

                    # OpenCV decodes to BGR; PIL expects RGB.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))

                    caption = f"Timestamp: {result['timestamp']:.2f}s\n"
                    caption += f"Relevance: {result['relevance_score']:.2f}"
                    captions.append(caption)
            finally:
                # Release the capture handle even if decoding fails midway.
                cap.release()

            # The details output is a Textbox, so hand it readable text
            # instead of a Python list.
            return frames, "\n\n".join(captions)

        except Exception as e:
            return [], f"Error querying video: {str(e)}"

    def create_interface(self):
        """Build the Gradio Blocks UI and wire it to this instance's handlers.

        Returns:
            The ``gr.Blocks`` interface; the caller is responsible for
            launching it.
        """
        with gr.Blocks(title="Video Chat RAG") as interface:
            gr.Markdown("# Video Chat RAG")
            gr.Markdown("Upload a video and ask questions about its content!")

            with gr.Row():
                video_input = gr.File(
                    label="Upload Video",
                    file_types=["video"],
                )
                process_button = gr.Button("Process Video")

            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )

            with gr.Row():
                query_input = gr.Textbox(
                    label="Ask about the video",
                    placeholder="What's happening in the video?"
                )
                query_button = gr.Button("Search")

            with gr.Row():
                gallery = gr.Gallery(
                    label="Retrieved Frames",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2],
                    rows=[2],
                    height="auto"
                )

            captions = gr.Textbox(
                label="Frame Details",
                interactive=False
            )

            # Upload -> index the video, report the outcome in the status box.
            process_button.click(
                fn=self.process_video,
                inputs=[video_input],
                outputs=[status_output]
            )

            # Question -> retrieved frames in the gallery plus their details.
            query_button.click(
                fn=self.query_video,
                inputs=[query_input],
                outputs=[gallery, captions]
            )

        return interface
|
|
|
|
|
# The application object and its UI are created at import time and exposed
# as the module-level names `app` and `interface`.
app = VideoRAGApp()
interface = app.create_interface()


if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    interface.launch()