mfarre (HF staff) committed
Commit: a8bd881
1 Parent(s): bd08551
Files changed (1):
  1. app.py +131 -214
app.py CHANGED
@@ -1,13 +1,13 @@
  import os
  import json
  import gradio as gr
- import tempfile
  import torch
  import spaces
  from pathlib import Path
- from transformers import AutoProcessor, AutoModelForVision2Seq
  import subprocess
  import logging

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
@@ -16,16 +16,13 @@ def load_examples(json_path: str) -> dict:
      with open(json_path, 'r') as f:
          return json.load(f)

- def format_duration(seconds: int) -> str:
-     hours = seconds // 3600
-     minutes = (seconds % 3600) // 60
-     secs = seconds % 60
-     if hours > 0:
-         return f"{hours}:{minutes:02d}:{secs:02d}"
-     return f"{minutes}:{secs:02d}"

  def get_video_duration_seconds(video_path: str) -> float:
-     """Use ffprobe to get video duration in seconds."""
      cmd = [
          "ffprobe",
          "-v", "quiet",
@@ -51,12 +48,10 @@ class VideoHighlightDetector:
          self.processor = AutoProcessor.from_pretrained(model_path)
          self.model = AutoModelForVision2Seq.from_pretrained(
              model_path,
-             torch_dtype=torch.bfloat16,
-             # _attn_implementation="flash_attention_2"
          ).to(device)

      def analyze_video_content(self, video_path: str) -> str:
-         """Analyze video content to determine its type and description."""
          system_message = "You are a helpful assistant that can understand videos. Describe what type of video this is and what's happening in it."
          messages = [
              {
@@ -83,24 +78,44 @@ class VideoHighlightDetector:
          outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
          return self.processor.decode(outputs[0], skip_special_tokens=True).lower().split("assistant: ")[1]

-     def determine_highlights(self, video_description: str) -> str:
-         """Determine what constitutes highlights based on video description."""
          messages = [
              {
                  "role": "system",
-                 "content": [{"type": "text", "text": "You are a professional video editor specializing in creating viral highlight reels. You understand that the most engaging highlights are brief and focus only on exceptional moments that are statistically rare or particularly dramatic. Moments that would make viewers say 'I can't believe that happened!"}]
              },
              {
                  "role": "user",
-                 "content": [{"type": "text", "text": f"""Here is a description of a video:
-
- {video_description}
-
- Based on this description, list which rare segments should be included in a best of the best highlight."""}]
              }
          ]

-         print(messages)

          inputs = self.processor.apply_chat_template(
              messages,
@@ -114,22 +129,15 @@ class VideoHighlightDetector:
          return self.processor.decode(outputs[0], skip_special_tokens=True).split("Assistant: ")[1]

      def process_segment(self, video_path: str, highlight_types: str) -> bool:
-         """Process a video segment and determine if it contains highlights."""
          messages = [
              {
                  "role": "user",
                  "content": [
                      {"type": "video", "path": video_path},
-                     {"type": "text", "text": f"""{highlight_types}
-
-
- Do you see any of those elements in the video? answer yes if you do and answer no if you don't."""}
                  ]
              }
          ]
-
-         print(messages)
-

          inputs = self.processor.apply_chat_template(
              messages,
@@ -141,82 +149,53 @@ class VideoHighlightDetector:

          outputs = self.model.generate(**inputs, max_new_tokens=64, do_sample=False)
          response = self.processor.decode(outputs[0], skip_special_tokens=True).lower().split("assistant: ")[1]
-         print(f"Segment response {response}")
          return "yes" in response

-     def _concatenate_scenes(
-         self,
-         video_path: str,
-         scene_times: list,
-         output_path: str
-     ):
-         """Concatenate selected scenes into final video."""
-         if not scene_times:
-             logger.warning("No scenes to concatenate, skipping.")
-             return
-
-         filter_complex_parts = []
-         concat_inputs = []
-         for i, (start_sec, end_sec) in enumerate(scene_times):
-             filter_complex_parts.append(
-                 f"[0:v]trim=start={start_sec}:end={end_sec},"
-                 f"setpts=PTS-STARTPTS[v{i}];"
-             )
-             filter_complex_parts.append(
-                 f"[0:a]atrim=start={start_sec}:end={end_sec},"
-                 f"asetpts=PTS-STARTPTS[a{i}];"
-             )
-             concat_inputs.append(f"[v{i}][a{i}]")
-
-         concat_filter = f"{''.join(concat_inputs)}concat=n={len(scene_times)}:v=1:a=1[outv][outa]"
-         filter_complex = "".join(filter_complex_parts) + concat_filter
-
-         cmd = [
-             "ffmpeg",
-             "-y",
-             "-i", video_path,
-             "-filter_complex", filter_complex,
-             "-map", "[outv]",
-             "-map", "[outa]",
-             "-c:v", "libx264",
-             "-c:a", "aac",
-             output_path
-         ]
-
-         logger.info(f"Running ffmpeg command: {' '.join(cmd)}")
-         subprocess.run(cmd, check=True)

  def create_ui(examples_path: str, model_path: str):
      examples_data = load_examples(examples_path)

      with gr.Blocks() as app:
-         gr.Markdown("# Video Highlight Generator")
-         gr.Markdown("Upload a video and get an automated highlight reel!")

-         with gr.Row():
-             gr.Markdown("## Example Results")
-
-         with gr.Row():
-             for example in examples_data["examples"]:
-                 with gr.Column():
-                     gr.Video(
-                         value=example["original"]["url"],
-                         label=f"Original ({format_duration(example['original']['duration_seconds'])})",
-                         interactive=False
-                     )
-                     gr.Markdown(f"### {example['title']}")
-
-                 with gr.Column():
-                     gr.Video(
-                         value=example["highlights"]["url"],
-                         label=f"Highlights ({format_duration(example['highlights']['duration_seconds'])})",
-                         interactive=False
-                     )
-                     with gr.Accordion("Chain of thought details", open=False):
-                         gr.Markdown(f"### Summary:\n{example['analysis']['video_description']}")
-                         gr.Markdown(f"### Highlights to search for:\n{example['analysis']['highlight_types']}")
-
-         gr.Markdown("## Try It Yourself!")
          with gr.Row():
              with gr.Column(scale=1):
                  input_video = gr.Video(
@@ -226,185 +205,128 @@ def create_ui(examples_path: str, model_path: str):
                  process_btn = gr.Button("Process Video", variant="primary")

              with gr.Column(scale=1):
-                 output_video = gr.Video(
-                     label="Highlight Video",
                      visible=False,
                      interactive=False,
                  )
-
                  status = gr.Markdown()
-
                  analysis_accordion = gr.Accordion(
-                     "Chain of thought details",
                      open=True,
                      visible=False
                  )

                  with analysis_accordion:
-                     video_description = gr.Markdown("", elem_id="video_desc")
-                     highlight_types = gr.Markdown("", elem_id="highlight_types")

          @spaces.GPU
          def on_process(video):
-             # Clear all components when starting new processing
-             yield [
-                 "", # Clear status
-                 "", # Clear video description
-                 "", # Clear highlight types
-                 gr.update(value=None, visible=False), # Clear video
-                 gr.update(visible=False) # Hide accordion
-             ]
-
              if not video:
-                 yield [
                      "Please upload a video",
                      "",
                      "",
-                     gr.update(visible=False),
                      gr.update(visible=False)
                  ]
-                 return

              try:
                  duration = get_video_duration_seconds(video)
-                 if duration > 1800: # 30 minutes
-                     yield [
                          "Video must be shorter than 30 minutes",
                          "",
                          "",
-                         gr.update(visible=False),
                          gr.update(visible=False)
                      ]
-                     return

-                 yield [
-                     "Initializing video highlight detector...",
-                     "",
-                     "",
-                     gr.update(visible=False),
-                     gr.update(visible=False)
-                 ]
-
-                 detector = VideoHighlightDetector(
-                     model_path=model_path,
-                     batch_size=8
-                 )
-
-                 yield [
-                     "Analyzing video content...",
-                     "",
-                     "",
-                     gr.update(visible=False),
-                     gr.update(visible=True)
-                 ]

                  video_desc = detector.analyze_video_content(video)
-                 formatted_desc = f"### Summary:\n {video_desc[:500] + '...' if len(video_desc) > 500 else video_desc}"
-
-                 yield [
-                     "Determining highlight types...",
-                     formatted_desc,
-                     "",
-                     gr.update(visible=False),
-                     gr.update(visible=True)
-                 ]

                  highlights = detector.determine_highlights(video_desc)
-                 formatted_highlights = f"### Highlights to search for:\n {highlights[:500] + '...' if len(highlights) > 500 else highlights}"
-
-                 # Split video into segments
-                 temp_dir = "temp_segments"
-                 os.makedirs(temp_dir, exist_ok=True)

                  segment_length = 10.0
-                 duration = get_video_duration_seconds(video)
                  kept_segments = []
-                 segments_processed = 0
-                 total_segments = int(duration / segment_length)
-
                  for start_time in range(0, int(duration), int(segment_length)):
-                     segments_processed += 1
-                     progress = int((segments_processed / total_segments) * 100)
-
-                     yield [
-                         f"Processing segments... {progress}% complete",
-                         formatted_desc,
-                         formatted_highlights,
-                         gr.update(visible=False),
-                         gr.update(visible=True)
-                     ]
-
-                     # Create segment
-                     segment_path = f"{temp_dir}/segment_{start_time}.mp4"
                      end_time = min(start_time + segment_length, duration)

-                     cmd = [
-                         "ffmpeg",
-                         "-y",
-                         "-i", video,
-                         "-ss", str(start_time),
-                         "-t", str(segment_length),
-                         "-c:v", "libx264",
-                         "-preset", "ultrafast", # Use ultrafast preset for speed
-                         "-pix_fmt", "yuv420p", # Ensure compatible pixel format
-                         segment_path
-                     ]
-                     subprocess.run(cmd, check=True)
-
-                     # Process segment
-                     if detector.process_segment(segment_path, highlights):
-                         print("KEEPING SEGMENT")
-                         kept_segments.append((start_time, end_time))
-
-                     # Clean up segment file
-                     os.remove(segment_path)
-
-                 # Remove temp directory
-                 os.rmdir(temp_dir)
-
-                 # Create final video
                  if kept_segments:
-                     with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
-                         temp_output = tmp_file.name
-                     detector._concatenate_scenes(video, kept_segments, temp_output)

-                     yield [
-                         "Processing complete!",
                          formatted_desc,
                          formatted_highlights,
-                         gr.update(value=temp_output, visible=True),
                          gr.update(visible=True)
                      ]
                  else:
-                     yield [
                          "No highlights detected in the video.",
                          formatted_desc,
                          formatted_highlights,
-                         gr.update(visible=False),
                          gr.update(visible=True)
                      ]

              except Exception as e:
                  logger.exception("Error processing video")
-                 yield [
                      f"Error processing video: {str(e)}",
                      "",
                      "",
-                     gr.update(visible=False),
                      gr.update(visible=False)
                  ]
              finally:
-                 # Clean up
                  torch.cuda.empty_cache()

          process_btn.click(
              on_process,
              inputs=[input_video],
              outputs=[
                  status,
                  video_description,
                  highlight_types,
-                 output_video,
                  analysis_accordion
              ],
              queue=True,
@@ -413,10 +335,5 @@ def create_ui(examples_path: str, model_path: str):
      return app

  if __name__ == "__main__":
-     # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
-
-     # Initialize CUDA
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
      app = create_ui("video_spec.json", "HuggingFaceTB/SmolVLM2-2.2B-Instruct")
      app.launch()
 
app.py (after change):

  import os
  import json
  import gradio as gr
  import torch
  import spaces
  from pathlib import Path
  import subprocess
  import logging
+ import xml.etree.ElementTree as ET
+ from xml.dom import minidom

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
 
      with open(json_path, 'r') as f:
          return json.load(f)

+ def format_duration(seconds: float) -> str:
+     hours = int(seconds // 3600)
+     minutes = int((seconds % 3600) // 60)
+     secs = int(seconds % 60)
+     return f"{hours:02d}:{minutes:02d}:{secs:02d}"

  def get_video_duration_seconds(video_path: str) -> float:
      cmd = [
          "ffprobe",
          "-v", "quiet",
 
          self.processor = AutoProcessor.from_pretrained(model_path)
          self.model = AutoModelForVision2Seq.from_pretrained(
              model_path,
+             torch_dtype=torch.bfloat16
          ).to(device)

      def analyze_video_content(self, video_path: str) -> str:
          system_message = "You are a helpful assistant that can understand videos. Describe what type of video this is and what's happening in it."
          messages = [
              {
 
          outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
          return self.processor.decode(outputs[0], skip_special_tokens=True).lower().split("assistant: ")[1]

+     def analyze_segment(self, video_path: str) -> str:
+         """Analyze a specific video segment and provide a brief description."""
          messages = [
              {
                  "role": "system",
+                 "content": [{"type": "text", "text": "Describe what is happening in this specific video segment in a brief, concise way."}]
              },
              {
                  "role": "user",
+                 "content": [
+                     {"type": "video", "path": video_path},
+                     {"type": "text", "text": "What is happening in this segment? Provide a brief description."}
+                 ]
              }
          ]
+
+         inputs = self.processor.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=True,
+             return_dict=True,
+             return_tensors="pt"
+         ).to(self.device)
+
+         outputs = self.model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
+         return self.processor.decode(outputs[0], skip_special_tokens=True).split("Assistant: ")[1]

+     def determine_highlights(self, video_description: str) -> str:
+         messages = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": "You are a professional video editor specializing in creating viral highlight reels."}]
+             },
+             {
+                 "role": "user",
+                 "content": [{"type": "text", "text": f"Based on this description, list which segments should be included in highlights: {video_description}"}]
+             }
+         ]

          inputs = self.processor.apply_chat_template(
              messages,

          return self.processor.decode(outputs[0], skip_special_tokens=True).split("Assistant: ")[1]

      def process_segment(self, video_path: str, highlight_types: str) -> bool:
          messages = [
              {
                  "role": "user",
                  "content": [
                      {"type": "video", "path": video_path},
+                     {"type": "text", "text": f"Do you see any of these elements in the video: {highlight_types}? Answer yes or no."}
                  ]
              }
          ]

          inputs = self.processor.apply_chat_template(
              messages,
 
          outputs = self.model.generate(**inputs, max_new_tokens=64, do_sample=False)
          response = self.processor.decode(outputs[0], skip_special_tokens=True).lower().split("assistant: ")[1]
          return "yes" in response

+ def create_xspf_playlist(video_path: str, segments: list, descriptions: list) -> str:
+     """Create XSPF playlist from segments with descriptions."""
+     root = ET.Element("playlist", version="1", xmlns="http://xspf.org/ns/0/")
+
+     # Get video filename for the title
+     video_filename = os.path.basename(video_path)
+     title = ET.SubElement(root, "title")
+     title.text = f"{video_filename} - Highlights"
+
+     tracklist = ET.SubElement(root, "trackList")
+
+     for idx, ((start_time, end_time), description) in enumerate(zip(segments, descriptions)):
+         track = ET.SubElement(tracklist, "track")
+
+         location = ET.SubElement(track, "location")
+         location.text = f"file:///{video_filename}"
+
+         title = ET.SubElement(track, "title")
+         title.text = f"Highlight {idx + 1}"
+
+         annotation = ET.SubElement(track, "annotation")
+         annotation.text = description
+
+         start_meta = ET.SubElement(track, "meta", rel="start")
+         start_meta.text = format_duration(start_time)
+
+         end_meta = ET.SubElement(track, "meta", rel="end")
+         end_meta.text = format_duration(end_time)
+
+     # Add VLC extension
+     extension = ET.SubElement(root, "extension", application="http://www.videolan.org/vlc/playlist/0")
+     for i in range(len(segments)):
+         item = ET.SubElement(extension, "vlc:item", tid=str(i))
+
+     # Convert to string with pretty printing
+     xml_str = minidom.parseString(ET.tostring(root)).toprettyxml(indent=" ")
+     return xml_str

  def create_ui(examples_path: str, model_path: str):
      examples_data = load_examples(examples_path)

      with gr.Blocks() as app:
+         gr.Markdown("# Video Highlight Playlist Generator")
+         gr.Markdown("Upload a video and get an XSPF playlist of highlights!")

          with gr.Row():
              with gr.Column(scale=1):
                  input_video = gr.Video(
 
                  process_btn = gr.Button("Process Video", variant="primary")

              with gr.Column(scale=1):
+                 output_playlist = gr.File(
+                     label="Highlight Playlist (XSPF)",
                      visible=False,
                      interactive=False,
                  )
                  status = gr.Markdown()
+
                  analysis_accordion = gr.Accordion(
+                     "Analysis Details",
                      open=True,
                      visible=False
                  )

                  with analysis_accordion:
+                     video_description = gr.Markdown("")
+                     highlight_types = gr.Markdown("")

          @spaces.GPU
          def on_process(video):
              if not video:
+                 return [
+                     None,
                      "Please upload a video",
                      "",
                      "",
                      gr.update(visible=False)
                  ]

              try:
                  duration = get_video_duration_seconds(video)
+                 if duration > 18000: # 300 minutes
+                     return [
+                         None,
                          "Video must be shorter than 30 minutes",
                          "",
                          "",
                          gr.update(visible=False)
                      ]

+                 detector = VideoHighlightDetector(model_path=model_path)

+                 # Analyze video content
                  video_desc = detector.analyze_video_content(video)
+                 formatted_desc = f"### Video Summary:\n{video_desc}"

+                 # Determine highlight types
                  highlights = detector.determine_highlights(video_desc)
+                 formatted_highlights = f"### Highlight Criteria:\n{highlights}"

+                 # Process video in segments
                  segment_length = 10.0
                  kept_segments = []
+                 segment_descriptions = []
+
                  for start_time in range(0, int(duration), int(segment_length)):
                      end_time = min(start_time + segment_length, duration)

+                     # Create temporary segment
+                     with tempfile.NamedTemporaryFile(suffix='.mp4') as temp_segment:
+                         cmd = [
+                             "ffmpeg",
+                             "-y",
+                             "-i", video,
+                             "-ss", str(start_time),
+                             "-t", str(segment_length),
+                             "-c:v", "libx264",
+                             "-preset", "ultrafast",
+                             temp_segment.name
+                         ]
+                         subprocess.run(cmd, check=True)
+
+                         if detector.process_segment(temp_segment.name, highlights):
+                             # Get segment description
+                             description = detector.analyze_segment(temp_segment.name)
+                             kept_segments.append((start_time, end_time))
+                             segment_descriptions.append(description)
+
                  if kept_segments:
+                     # Create XSPF playlist
+                     playlist_content = create_xspf_playlist(video, kept_segments, segment_descriptions)
+
+                     # Save playlist to temporary file
+                     with tempfile.NamedTemporaryFile(mode='w', suffix='.xspf', delete=False) as f:
+                         f.write(playlist_content)
+                         playlist_path = f.name

+                     return [
+                         gr.update(value=playlist_path, visible=True),
+                         "Processing complete! Download the XSPF playlist.",
                          formatted_desc,
                          formatted_highlights,
                          gr.update(visible=True)
                      ]
                  else:
+                     return [
+                         None,
                          "No highlights detected in the video.",
                          formatted_desc,
                          formatted_highlights,
                          gr.update(visible=True)
                      ]

              except Exception as e:
                  logger.exception("Error processing video")
+                 return [
+                     None,
                      f"Error processing video: {str(e)}",
                      "",
                      "",
                      gr.update(visible=False)
                  ]
              finally:
                  torch.cuda.empty_cache()

          process_btn.click(
              on_process,
              inputs=[input_video],
              outputs=[
+                 output_playlist,
                  status,
                  video_description,
                  highlight_types,
                  analysis_accordion
              ],
              queue=True,
 
      return app

  if __name__ == "__main__":
      app = create_ui("video_spec.json", "HuggingFaceTB/SmolVLM2-2.2B-Instruct")
      app.launch()
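
For reference, a minimal sketch of how the new create_xspf_playlist helper could be exercised outside the Gradio UI. It assumes app.py is importable from the working directory; the video path, segment times, and descriptions below are invented example values, not real model output.

# Sketch only: "my_video.mp4" and the segment data below are made-up examples.
from app import create_xspf_playlist

segments = [(10, 20), (40, 50)]  # (start_seconds, end_seconds) pairs kept by the detector
descriptions = [
    "Example description of the first kept segment",
    "Example description of the second kept segment",
]

xspf = create_xspf_playlist("my_video.mp4", segments, descriptions)

# Write the playlist to disk for inspection or for opening in a player such as VLC.
with open("my_video_highlights.xspf", "w") as f:
    f.write(xspf)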