capradeepgujaran committed
Commit d8bea64 · verified · 1 Parent(s): ad632e1

Update app.py

Files changed (1):
  1. app.py  +46 -22
app.py CHANGED
@@ -1,6 +1,6 @@
 import cv2
 import numpy as np
-from transformers import CLIPProcessor, CLIPModel
+from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration
 import torch
 from PIL import Image
 import faiss
@@ -13,13 +13,21 @@ import os
 import shutil

 class VideoRAGTool:
-    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
+    def __init__(self, clip_model_name: str = "openai/clip-vit-base-patch32",
+                 blip_model_name: str = "Salesforce/blip-image-captioning-base"):
         """
-        Initialize the Video RAG Tool with CLIP model for frame analysis.
+        Initialize the Video RAG Tool with CLIP and BLIP models for frame analysis and captioning.
         """
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
-        self.processor = CLIPProcessor.from_pretrained(model_name)
+
+        # Initialize CLIP for frame retrieval
+        self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(self.device)
+        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
+
+        # Initialize BLIP for image captioning
+        self.blip_processor = BlipProcessor.from_pretrained(blip_model_name)
+        self.blip_model = BlipForConditionalGeneration.from_pretrained(blip_model_name).to(self.device)
+
         self.frame_index = None
         self.frame_data = []
         self.logger = self._setup_logger()
@@ -33,6 +41,13 @@ class VideoRAGTool:
         logger.addHandler(handler)
         return logger

+    def generate_caption(self, image: Image.Image) -> str:
+        """Generate a description for the given image using BLIP."""
+        inputs = self.blip_processor(image, return_tensors="pt").to(self.device)
+        out = self.blip_model.generate(**inputs)
+        caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
+        return caption
+
     def process_video(self, video_path: str, frame_interval: int = 30) -> None:
         """Process video file and extract features from frames."""
         self.logger.info(f"Processing video: {video_path}")
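For reference, the BLIP call that generate_caption wraps can be exercised on its own. A minimal sketch, assuming the same checkpoint as above; the image path and max_new_tokens value are illustrative, not part of the commit:

# Standalone BLIP captioning sketch (transformers + Pillow assumed installed;
# "frame.jpg" and max_new_tokens=30 are illustrative choices).
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

image = Image.open("frame.jpg").convert("RGB")        # illustrative input frame
inputs = processor(image, return_tensors="pt").to(device)
out = model.generate(**inputs, max_new_tokens=30)     # length cap is an assumption
print(processor.decode(out[0], skip_special_tokens=True))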
@@ -49,12 +64,17 @@
                 frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                 image = Image.fromarray(frame_rgb)

-                inputs = self.processor(images=image, return_tensors="pt").to(self.device)
-                image_features = self.model.get_image_features(**inputs)
+                # Generate caption for the frame
+                caption = self.generate_caption(image)
+
+                # Process frame with CLIP
+                inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
+                image_features = self.clip_model.get_image_features(**inputs)

                 self.frame_data.append({
                     'frame_number': frame_count,
-                    'timestamp': frame_count / cap.get(cv2.CAP_PROP_FPS)
+                    'timestamp': frame_count / cap.get(cv2.CAP_PROP_FPS),
+                    'caption': caption
                 })
                 features_list.append(image_features.cpu().detach().numpy())

@@ -75,8 +95,8 @@
         """Query the video using natural language and return relevant frames."""
         self.logger.info(f"Processing query: {query_text}")

-        inputs = self.processor(text=[query_text], return_tensors="pt").to(self.device)
-        text_features = self.model.get_text_features(**inputs)
+        inputs = self.clip_processor(text=[query_text], return_tensors="pt").to(self.device)
+        text_features = self.clip_model.get_text_features(**inputs)

         distances, indices = self.frame_index.search(
             text_features.cpu().detach().numpy(),
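Neither the construction of self.frame_index from features_list nor the tail of query_video appears in these hunks. A minimal sketch of how the two pieces could fit together, assuming an inner-product FAISS index over L2-normalized CLIP features; the index type and the helpers below are assumptions, not code from this commit:

import faiss
import numpy as np

# Hypothetical index construction from the per-frame CLIP features collected in
# process_video(); IndexFlatIP over L2-normalized vectors gives cosine similarity.
def build_frame_index(features_list):
    features = np.vstack(features_list).astype("float32")   # (num_frames, dim)
    faiss.normalize_L2(features)                             # in-place normalization
    index = faiss.IndexFlatIP(features.shape[1])
    index.add(features)
    return index

# Hypothetical mapping from FAISS search output back to the stored frame metadata;
# the 'caption' and 'relevance_score' keys mirror how the app formats results below.
def hits_to_results(distances, indices, frame_data):
    results = []
    for score, idx in zip(distances[0], indices[0]):
        info = frame_data[idx]
        results.append({
            'frame_number': info['frame_number'],
            'timestamp': info['timestamp'],
            'caption': info['caption'],
            'relevance_score': float(score),
        })
    return results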
@@ -109,10 +129,7 @@ class VideoRAGApp:
         if video_file is None:
             return "Please upload a video first."

-        # video_file is now a file path provided by Gradio
         video_path = video_file.name
-
-        # Create a copy in our temp directory
         temp_video_path = os.path.join(self.temp_dir, "current_video.mp4")
         shutil.copy2(video_path, temp_video_path)

@@ -135,7 +152,7 @@
             results = self.rag_tool.query_video(query_text, k=4)

             frames = []
-            captions = []
+            descriptions = []

             cap = cv2.VideoCapture(self.current_video_path)

@@ -148,13 +165,19 @@
                 frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                 frames.append(Image.fromarray(frame_rgb))

-                caption = f"Timestamp: {result['timestamp']:.2f}s\n"
-                caption += f"Relevance: {result['relevance_score']:.2f}"
-                captions.append(caption)
+                description = f"Timestamp: {result['timestamp']:.2f}s\n"
+                description += f"Scene Description: {result['caption']}\n"
+                description += f"Relevance Score: {result['relevance_score']:.2f}"
+                descriptions.append(description)

             cap.release()

-            return frames, "\n\n".join(captions)
+            # Combine all descriptions with frame numbers
+            combined_description = "\n\nFrame Analysis:\n\n"
+            for i, desc in enumerate(descriptions, 1):
+                combined_description += f"Frame {i}:\n{desc}\n\n"
+
+            return frames, combined_description

         except Exception as e:
             return None, f"Error querying video: {str(e)}"
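Outside the Gradio app, the updated VideoRAGTool can be driven directly. A minimal usage sketch, assuming app.py is importable as a module and a local sample.mp4 exists (both names and the query string are illustrative):

# End-to-end usage sketch of VideoRAGTool without the Gradio UI.
from app import VideoRAGTool  # assumes this file is importable as "app"

tool = VideoRAGTool()
tool.process_video("sample.mp4", frame_interval=30)   # caption + embed every 30th frame
for result in tool.query_video("a person opening a door", k=4):
    print(f"{result['timestamp']:.2f}s  {result['caption']}  (score {result['relevance_score']:.2f})")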
@@ -194,9 +217,10 @@ class VideoRAGApp:
                 height="auto"
             )

-            captions = gr.Textbox(
-                label="Frame Details",
-                interactive=False
+            descriptions = gr.Textbox(
+                label="Scene Descriptions",
+                interactive=False,
+                lines=10
             )

         process_button.click(
@@ -208,7 +232,7 @@
         query_button.click(
             fn=self.query_video,
             inputs=[query_input],
-            outputs=[gallery, captions]
+            outputs=[gallery, descriptions]
         )

         return interface
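The widget wiring this hunk touches follows the usual gr.Blocks pattern. A self-contained sketch of the gallery-plus-textbox hookup; the dummy handler and labels are illustrative, not taken from the commit:

# Minimal gr.Blocks sketch of the gallery + textbox wiring updated above.
import gradio as gr

def dummy_query(query_text):
    # Stand-in for VideoRAGApp.query_video: returns (list of images, text)
    return [], f"No video processed yet (query was: {query_text})"

with gr.Blocks() as demo:
    query_input = gr.Textbox(label="Query")
    query_button = gr.Button("Search")
    gallery = gr.Gallery(label="Retrieved Frames")
    descriptions = gr.Textbox(label="Scene Descriptions", interactive=False, lines=10)

    query_button.click(fn=dummy_query, inputs=[query_input], outputs=[gallery, descriptions])

if __name__ == "__main__":
    demo.launch()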
 