capradeepgujaran committed on
Commit
9d6df4b
·
verified ·
1 Parent(s): 5a09cf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -13
app.py CHANGED
@@ -1,11 +1,91 @@
1
- import gradio as gr
2
- from video_rag_tool import VideoRAGTool
3
- import tempfile
4
- import os
5
- from PIL import Image
6
  import cv2
7
  import numpy as np
 
8
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  class VideoRAGApp:
11
  def __init__(self):
@@ -18,7 +98,6 @@ class VideoRAGApp:
18
  if video_file is None:
19
  return "Please upload a video first."
20
 
21
- # Save uploaded video to temporary file
22
  temp_dir = tempfile.mkdtemp()
23
  temp_path = os.path.join(temp_dir, "uploaded_video.mp4")
24
 
@@ -37,12 +116,11 @@ class VideoRAGApp:
37
  def query_video(self, query_text):
38
  """Query the video and return relevant frames with descriptions"""
39
  if not self.processed:
40
- return "Please process a video first."
41
 
42
  try:
43
  results = self.rag_tool.query_video(query_text, k=4)
44
 
45
- # Extract frames for display
46
  frames = []
47
  captions = []
48
 
@@ -63,10 +141,10 @@ class VideoRAGApp:
63
 
64
  cap.release()
65
 
66
- return frames, captions
67
 
68
  except Exception as e:
69
- return f"Error querying video: {str(e)}"
70
 
71
  def create_interface(self):
72
  """Create and return Gradio interface"""
@@ -108,7 +186,6 @@ class VideoRAGApp:
108
  interactive=False
109
  )
110
 
111
- # Set up event handlers
112
  process_button.click(
113
  fn=self.process_video,
114
  inputs=[video_input],
@@ -123,10 +200,10 @@ class VideoRAGApp:
123
 
124
  return interface
125
 
126
- # For Hugging Face Spaces deployment
127
  app = VideoRAGApp()
128
  interface = app.create_interface()
129
 
130
- # Launch the app (for local testing)
131
  if __name__ == "__main__":
132
  interface.launch()
 
 
 
 
 
 
1
  import cv2
2
  import numpy as np
3
+ from transformers import CLIPProcessor, CLIPModel
4
  import torch
5
+ from PIL import Image
6
+ import faiss
7
+ import pickle
8
+ from typing import List, Dict, Tuple
9
+ import logging
10
+ import gradio as gr
11
+ import tempfile
12
+ import os
13
+
14
class VideoRAGTool:
    """Index sampled video frames with CLIP and retrieve them by text query.

    ``process_video`` embeds every ``frame_interval``-th frame with the CLIP
    image encoder and stores the embeddings in a flat L2 FAISS index.
    ``query_video`` embeds a text query with the CLIP text encoder and returns
    metadata for the nearest indexed frames.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """
        Initialize the Video RAG Tool with CLIP model for frame analysis.

        Args:
            model_name: Identifier passed to ``CLIPModel.from_pretrained`` and
                ``CLIPProcessor.from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        # FAISS index over frame embeddings; None until a video is processed.
        self.frame_index = None
        # One metadata dict per indexed frame; position i matches FAISS id i.
        self.frame_data = []
        self.logger = self._setup_logger()

    def _setup_logger(self) -> logging.Logger:
        """Return the shared 'VideoRAGTool' logger, configured at most once."""
        logger = logging.getLogger('VideoRAGTool')
        logger.setLevel(logging.INFO)
        # logging.getLogger returns the same object on every call, so only
        # attach a handler the first time; otherwise each new instance would
        # duplicate every log line.
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    def process_video(self, video_path: str, frame_interval: int = 30) -> None:
        """Process video file and extract features from frames.

        Replaces any previously built index and frame metadata, so the tool
        always describes the most recently processed video.

        Args:
            video_path: Path of the video file to open with OpenCV.
            frame_interval: Sample one frame out of every ``frame_interval``
                frames; must be at least 1.

        Raises:
            ValueError: If ``frame_interval`` is less than 1, the video cannot
                be opened, or no frame could be read from it.
        """
        if frame_interval < 1:
            raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")

        self.logger.info(f"Processing video: {video_path}")
        cap = cv2.VideoCapture(video_path)
        # Build into locals and publish only on success, so a failed run does
        # not leave the index and the metadata out of sync.
        frame_data = []
        features_list = []

        try:
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {video_path}")

            # FPS is constant for the capture; read it once. Some containers
            # report 0, which would otherwise cause a division by zero.
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps or fps <= 0:
                self.logger.warning("Video reports no FPS; timestamps will be 0.0")
                fps = 0.0

            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % frame_interval == 0:
                    # OpenCV decodes to BGR; CLIP preprocessing expects RGB.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    image = Image.fromarray(frame_rgb)

                    inputs = self.processor(images=image, return_tensors="pt").to(self.device)
                    # Inference only: skip autograd bookkeeping to save memory.
                    with torch.no_grad():
                        image_features = self.model.get_image_features(**inputs)

                    frame_data.append({
                        'frame_number': frame_count,
                        'timestamp': frame_count / fps if fps else 0.0,
                    })
                    features_list.append(image_features.cpu().numpy())

                frame_count += 1
        finally:
            # Release the capture even if decoding or inference raised.
            cap.release()

        if not features_list:
            raise ValueError(f"No frames could be read from video: {video_path}")

        # FAISS requires a C-contiguous float32 matrix.
        features_array = np.ascontiguousarray(np.vstack(features_list), dtype=np.float32)
        frame_index = faiss.IndexFlatL2(features_array.shape[1])
        frame_index.add(features_array)

        self.frame_index = frame_index
        self.frame_data = frame_data

        self.logger.info(f"Processed {len(self.frame_data)} frames from video")

    def query_video(self, query_text: str, k: int = 5) -> List[Dict]:
        """Query the video using natural language and return relevant frames.

        Args:
            query_text: Natural-language description of the frames to find.
            k: Maximum number of frames to return.

        Returns:
            Up to ``k`` copies of the frame metadata dicts, nearest first, each
            with an added ``'relevance_score'`` equal to ``1 / (1 + distance)``.
            Returns an empty list when no video has been processed or ``k`` is
            less than 1.
        """
        self.logger.info(f"Processing query: {query_text}")

        if k < 1:
            return []

        index = self.frame_index
        if index is None or not self.frame_data:
            self.logger.warning("Query received before any video was processed")
            return []

        # Asking FAISS for more neighbours than it holds pads the result with
        # id -1, which would silently alias frame_data[-1].
        k = min(k, index.ntotal, len(self.frame_data))
        if k < 1:
            return []

        inputs = self.processor(
            text=[query_text],
            return_tensors="pt",
            padding=True,
            truncation=True,  # keep over-long queries within the model's context length
        ).to(self.device)
        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)

        query_vector = np.ascontiguousarray(text_features.cpu().numpy(), dtype=np.float32)
        distances, indices = index.search(query_vector, k)

        results = []
        for distance, idx in zip(distances[0], indices[0]):
            if idx < 0:
                # Defensive: FAISS marks "no neighbour" with -1.
                continue
            frame_info = self.frame_data[int(idx)].copy()
            frame_info['relevance_score'] = float(1 / (1 + distance))
            results.append(frame_info)

        return results
89
 
90
  class VideoRAGApp:
91
  def __init__(self):
 
98
  if video_file is None:
99
  return "Please upload a video first."
100
 
 
101
  temp_dir = tempfile.mkdtemp()
102
  temp_path = os.path.join(temp_dir, "uploaded_video.mp4")
103
 
 
116
  def query_video(self, query_text):
117
  """Query the video and return relevant frames with descriptions"""
118
  if not self.processed:
119
+ return None, "Please process a video first."
120
 
121
  try:
122
  results = self.rag_tool.query_video(query_text, k=4)
123
 
 
124
  frames = []
125
  captions = []
126
 
 
141
 
142
  cap.release()
143
 
144
+ return frames, "\n\n".join(captions)
145
 
146
  except Exception as e:
147
+ return None, f"Error querying video: {str(e)}"
148
 
149
  def create_interface(self):
150
  """Create and return Gradio interface"""
 
186
  interactive=False
187
  )
188
 
 
189
  process_button.click(
190
  fn=self.process_video,
191
  inputs=[video_input],
 
200
 
201
  return interface
202
 
203
+ # Initialize and create the interface
204
  app = VideoRAGApp()
205
  interface = app.create_interface()
206
 
207
+ # Launch the app
208
  if __name__ == "__main__":
209
  interface.launch()