aiqcamp committed Β· verified
Commit e4f7d0b Β· 1 Parent(s): 6feb9f3

Update app.py

Files changed (1): app.py (+117 -39)
app.py CHANGED
@@ -6,14 +6,18 @@ import os
 import tempfile
 import spaces
 import gradio as gr
-
 import subprocess
 import sys
+import cv2
+import threading
+import queue
+import time
+from collections import deque
+from deep_translator import GoogleTranslator
 
 def install_flash_attn_wheel():
     flash_attn_wheel_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
     try:
-        # Call pip to install the wheel file
         subprocess.check_call([sys.executable, "-m", "pip", "install", flash_attn_wheel_url])
         print("Wheel installed successfully!")
     except subprocess.CalledProcessError as e:
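The hunk above unconditionally downloads and installs a pinned flash-attention wheel on every startup. A minimal sketch of an optional guard that skips the pip call when flash_attn is already importable; the find_spec check and the function name are assumptions, not part of this commit:

import importlib.util
import subprocess
import sys

FLASH_ATTN_WHEEL_URL = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/"
    "flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
)

def install_flash_attn_wheel_if_missing():
    # Skip the network/pip call when flash_attn is already installed (assumption: checking
    # importability is an acceptable proxy for "installed").
    if importlib.util.find_spec("flash_attn") is not None:
        print("flash_attn already available, skipping install.")
        return
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", FLASH_ATTN_WHEEL_URL])
        print("Wheel installed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install flash_attn wheel: {e}")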
@@ -21,7 +25,6 @@ def install_flash_attn_wheel():
 
 install_flash_attn_wheel()
 
-import cv2
 try:
     from mmengine.visualization import Visualizer
 except ImportError:
@@ -43,25 +46,75 @@ tokenizer = AutoTokenizer.from_pretrained(
     trust_remote_code = True,
 )
 
+class WebcamProcessor:
+    def __init__(self, model, tokenizer, fps_target=15, buffer_size=5):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.fps_target = fps_target
+        self.frame_interval = 1.0 / fps_target
+        self.buffer_size = buffer_size
+        self.frame_buffer = deque(maxlen=buffer_size)
+        self.result_queue = queue.Queue()
+        self.is_running = False
+        self.last_process_time = 0
+
+    def start(self):
+        self.is_running = True
+        self.capture = cv2.VideoCapture(0)
+        self.capture_thread = threading.Thread(target=self._capture_loop)
+        self.process_thread = threading.Thread(target=self._process_loop)
+        self.capture_thread.start()
+        self.process_thread.start()
+
+    def stop(self):
+        self.is_running = False
+        if hasattr(self, 'capture_thread'):
+            self.capture_thread.join()
+            self.process_thread.join()
+            self.capture.release()
+
+    def _capture_loop(self):
+        while self.is_running:
+            ret, frame = self.capture.read()
+            if ret:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                frame = cv2.resize(frame, (640, 480))
+                current_time = time.time()
+                if current_time - self.last_process_time >= self.frame_interval:
+                    self.frame_buffer.append(frame)
+                    self.last_process_time = current_time
+
+    def _process_loop(self):
+        while self.is_running:
+            if len(self.frame_buffer) >= self.buffer_size:
+                frames = list(self.frame_buffer)
+                try:
+                    result = self.model.predict_forward(
+                        video=frames,
+                        text="<image>Describe what you see",
+                        tokenizer=self.tokenizer
+                    )
+                    self.result_queue.put(result)
+                except Exception as e:
+                    print(f"Processing error: {e}")
+                self.frame_buffer.clear()
+            time.sleep(0.1)
+
 from third_parts import VideoReader
 def read_video(video_path, video_interval):
     vid_frames = VideoReader(video_path)[::video_interval]
 
     temp_dir = tempfile.mkdtemp()
     os.makedirs(temp_dir, exist_ok=True)
-    image_paths = [] # List to store paths of saved images
+    image_paths = []
 
     for frame_idx in range(len(vid_frames)):
         frame_image = vid_frames[frame_idx]
-        frame_image = frame_image[..., ::-1] # BGR (opencv system) to RGB (numpy system)
+        frame_image = frame_image[..., ::-1]
         frame_image = Image.fromarray(frame_image)
         vid_frames[frame_idx] = frame_image
-
-        # Save the frame as a .jpg file in the temporary folder
         image_path = os.path.join(temp_dir, f"frame_{frame_idx:04d}.jpg")
         frame_image.save(image_path, format="JPEG")
-
-        # Append the image path to the list
        image_paths.append(image_path)
     return vid_frames, image_paths
 
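The new WebcamProcessor pairs a capture thread (a bounded deque of RGB frames) with a processing thread that calls model.predict_forward and pushes results onto a queue. A minimal usage sketch, assuming a local webcam at index 0 and the model/tokenizer objects this app already loads; the polling loop is illustrative only:

import queue
import time

# Hypothetical driver; WebcamProcessor, model and tokenizer come from app.py.
processor = WebcamProcessor(model, tokenizer, fps_target=15, buffer_size=5)
processor.start()
try:
    for _ in range(3):  # poll a few results, then shut down
        try:
            result = processor.result_queue.get(timeout=10)
            print(result["prediction"])
        except queue.Empty:
            print("No result yet")
        time.sleep(1)
finally:
    processor.stop()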
@@ -71,17 +124,10 @@ def visualize(pred_mask, image_path, work_dir):
     visualizer.set_image(img)
     visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
     visual_result = visualizer.get_image()
-
     output_path = os.path.join(work_dir, os.path.basename(image_path))
     cv2.imwrite(output_path, visual_result)
     return output_path
 
-
-
-# Add the import at the top of the file
-from deep_translator import GoogleTranslator
-
-# Revised translation function
 def translate_to_korean(text):
     try:
         translator = GoogleTranslator(source='en', target='ko')
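translate_to_korean wraps deep_translator's GoogleTranslator; the hunk is cut off right after the translator is constructed. A sketch of how such a wrapper is typically completed, where the .translate call matches deep_translator's API but the fall-back-to-original-text behaviour is an assumption rather than something this diff shows:

from deep_translator import GoogleTranslator

def translate_to_korean(text):
    # Return the Korean translation; fall back to the original text on error (assumed behaviour).
    try:
        translator = GoogleTranslator(source="en", target="ko")
        return translator.translate(text)
    except Exception as e:
        print(f"Translation failed: {e}")
        return text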
@@ -92,7 +138,6 @@ def translate_to_korean(text):
 
 @spaces.GPU
 def image_vision(image_input_path, prompt):
-    # Check whether the prompt contains Korean
     is_korean = any(ord('κ°€') <= ord(char) <= ord('힣') for char in prompt)
 
     image_path = image_input_path
@@ -109,9 +154,7 @@ def image_vision(image_input_path, prompt):
     print(return_dict)
     answer = return_dict["prediction"]
 
-    # If the prompt is Korean, translate the response into Korean
     if is_korean:
-        # Preserve [SEG] and translate only the remaining text
         if '[SEG]' in answer:
             parts = answer.split('[SEG]')
             translated_parts = [translate_to_korean(part.strip()) for part in parts]
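The change above keeps the Korean path while dropping its comments: when the answer contains the '[SEG]' marker it is split on that marker, each fragment is translated, and the markers are preserved. The rejoin step falls outside this hunk, so the sketch below is one plausible reconstruction, using an identity translator so it runs standalone:

def translate_preserving_seg(answer, translate=lambda s: s):
    # Split on the [SEG] marker, translate the fragments, and put the markers back
    # so downstream mask handling still finds them. The rejoin format is assumed.
    if "[SEG]" not in answer:
        return translate(answer)
    parts = answer.split("[SEG]")
    translated_parts = [translate(part.strip()) for part in parts]
    return " [SEG] ".join(translated_parts)

print(translate_preserving_seg("a red car [SEG] parked next to a tree [SEG]"))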
@@ -133,7 +176,6 @@
 
 @spaces.GPU(duration=80)
 def video_vision(video_input_path, prompt, video_interval):
-    # Check whether the prompt contains Korean
     is_korean = any(ord('κ°€') <= ord(char) <= ord('힣') for char in prompt)
 
     cap = cv2.VideoCapture(video_input_path)
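video_vision reads the clip with cv2.VideoCapture, while read_video samples every video_interval-th frame through third_parts.VideoReader and flips BGR to RGB. For reference, a rough OpenCV-only equivalent of that sampling (a hypothetical helper, not the code path this app uses):

import cv2

def sample_frames(video_path, video_interval):
    # Read every video_interval-th frame and convert BGR (OpenCV) to RGB.
    cap = cv2.VideoCapture(video_path)
    frames = []
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if idx % video_interval == 0:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        idx += 1
    cap.release()
    return frames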
@@ -151,7 +193,6 @@ def video_vision(video_input_path, prompt, video_interval):
     prediction = result['prediction']
     print(prediction)
 
-    # If the prompt is Korean, translate the response into Korean
     if is_korean:
         if '[SEG]' in prediction:
             parts = prediction.split('[SEG]')
@@ -185,31 +226,38 @@ def video_vision(video_input_path, prompt, video_interval):
         print(f"Video created successfully at {output_video}")
 
         return prediction, output_video
-
     else:
         return prediction, None
 
-# Gradio UI
+@spaces.GPU
+def webcam_vision(prompt):
+    is_korean = any(ord('κ°€') <= ord(char) <= ord('힣') for char in prompt)
+
+    if not hasattr(webcam_vision, 'processor'):
+        webcam_vision.processor = WebcamProcessor(model, tokenizer)
+
+    if not webcam_vision.processor.is_running:
+        webcam_vision.processor.start()
+
+    try:
+        result = webcam_vision.processor.result_queue.get(timeout=5)
+        prediction = result['prediction']
+
+        if is_korean:
+            prediction = translate_to_korean(prediction)
+
+        return prediction
+    except queue.Empty:
+        return "No results available yet"
+    except Exception as e:
+        return f"Error: {str(e)}"
 
+# Gradio UI
 with gr.Blocks(analytics_enabled=False) as demo:
     with gr.Column():
         gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
-        gr.HTML("""
-        <div style="display:flex;column-gap:4px;">
-            <a href="https://github.com/magic-research/Sa2VA">
-                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
-            </a>
-            <a href="https://arxiv.org/abs/2501.04001">
-                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
-            </a>
-            <a href="https://huggingface.co/spaces/fffiloni/Sa2VA-simple-demo?duplicate=true">
-                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
-            </a>
-            <a href="https://huggingface.co/fffiloni">
-                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
-            </a>
-        </div>
-        """)
+
+
         with gr.Tab("Single Image"):
             with gr.Row():
                 with gr.Column():
@@ -226,6 +274,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
                 inputs = [image_input, instruction],
                 outputs = [output_res, output_image]
             )
+
         with gr.Tab("Video"):
             with gr.Row():
                 with gr.Column():
@@ -243,5 +292,34 @@ with gr.Blocks(analytics_enabled=False) as demo:
                 inputs = [video_input, vid_instruction, frame_interval],
                 outputs = [vid_output_res, output_video]
             )
+
+        with gr.Tab("Webcam"):
+            with gr.Row():
+                with gr.Column():
+                    webcam_input = gr.Image(source="webcam", streaming=True)
+                    with gr.Row():
+                        webcam_instruction = gr.Textbox(
+                            label="Instruction",
+                            placeholder="Enter instruction here...",
+                            scale=4
+                        )
+                        start_button = gr.Button("Start", scale=1)
+                        stop_button = gr.Button("Stop", scale=1)
+                with gr.Column():
+                    webcam_output = gr.Textbox(label="Response")
+                    processed_view = gr.Image(label="Processed View")
+
+            status_text = gr.Textbox(label="Status", value="Ready")
+
+            start_button.click(
+                fn=lambda x: webcam_vision(x),
+                inputs=[webcam_instruction],
+                outputs=[webcam_output]
+            )
+
+            stop_button.click(
+                fn=lambda: "Stopped" if hasattr(webcam_vision, 'processor') and webcam_vision.processor.stop() else "Not running",
+                outputs=[status_text]
+            )
 
 demo.queue().launch(show_api=False, show_error=True)
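One caveat in the new Webcam tab: the stop handler keys off the truthiness of webcam_vision.processor.stop(), which returns None, so the lambda reports "Not running" even after a successful stop. A sketch of a more explicit handler with the same intent, offered as a hypothetical replacement rather than part of the commit:

def stop_webcam():
    # Stop the background processor if one is running and report the outcome;
    # webcam_vision, stop_button and status_text come from the app above.
    if hasattr(webcam_vision, "processor") and webcam_vision.processor.is_running:
        webcam_vision.processor.stop()
        return "Stopped"
    return "Not running"

stop_button.click(fn=stop_webcam, outputs=[status_text])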
 