mfarre HF staff committed on
Commit
7f59bd0
·
1 Parent(s): 9c83cd8

update highlight-selection algorithm

Browse files
Files changed (1) hide show
  1. app.py +63 -18
app.py CHANGED
@@ -108,18 +108,30 @@ class VideoHighlightDetector:
108
  outputs = self.model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
109
  return self.processor.decode(outputs[0], skip_special_tokens=True).split("Assistant: ")[1]
110
 
111
- def determine_highlights(self, video_description: str) -> str:
112
- """Determine what constitutes highlights based on video description."""
 
 
 
 
 
 
 
 
 
113
  messages = [
114
  {
115
  "role": "system",
116
- "content": [{"type": "text", "text": "You are a highlight editor. List archetypal dramatic moments that would make compelling highlights if they appear in the video. Each moment should be specific enough to be recognizable but generic enough to potentially exist in any video of this type."}]
117
  },
118
  {
119
  "role": "user",
120
- "content": [{"type": "text", "text": f"""Here is a description of a video:\n\n{video_description}\n\nList potential highlight moments to look for in this video:"""}]
121
  }
122
  ]
 
 
 
123
 
124
  inputs = self.processor.apply_chat_template(
125
  messages,
@@ -265,7 +277,7 @@ def create_ui(examples_path: str, model_path: str):
265
  gr.update(visible=False)
266
  ]
267
 
268
- detector = VideoHighlightDetector(model_path=model_path)
269
 
270
  yield [
271
  None,
@@ -287,20 +299,21 @@ def create_ui(examples_path: str, model_path: str):
287
  gr.update(visible=True)
288
  ]
289
 
290
- # Determine highlight types
291
- highlights = detector.determine_highlights(video_desc)
292
- formatted_highlights = f"### Highlight Criteria:\n{highlights}"
293
 
294
  # Process video in segments
295
  segment_length = 10.0
296
- kept_segments = []
297
- segment_descriptions = []
 
 
298
  segments_processed = 0
299
  total_segments = int(duration / segment_length)
300
 
301
  for start_time in range(0, int(duration), int(segment_length)):
302
  end_time = min(start_time + segment_length, duration)
303
- segments_processed +=1
304
  progress = int((segments_processed / total_segments) * 100)
305
 
306
  yield [
@@ -325,24 +338,56 @@ def create_ui(examples_path: str, model_path: str):
325
  ]
326
  subprocess.run(cmd, check=True)
327
 
328
- if detector.process_segment(temp_segment.name, highlights):
329
- # Get segment description
 
 
 
 
 
330
  description = detector.analyze_segment(temp_segment.name)
331
- kept_segments.append((start_time, end_time))
332
- segment_descriptions.append(description)
 
 
 
 
 
 
 
333
 
334
- if kept_segments:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  # Create XSPF playlist
336
- playlist_content = create_xspf_playlist(video, kept_segments, segment_descriptions)
337
 
338
  # Save playlist to temporary file
339
  with tempfile.NamedTemporaryFile(mode='w', suffix='.xspf', delete=False) as f:
340
  f.write(playlist_content)
341
  playlist_path = f.name
342
 
 
 
343
  yield [
344
  gr.update(value=playlist_path, visible=True),
345
- "Processing complete! You can download the playlist.",
346
  formatted_desc,
347
  formatted_highlights,
348
  gr.update(visible=True)
 
108
  outputs = self.model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
109
  return self.processor.decode(outputs[0], skip_special_tokens=True).split("Assistant: ")[1]
110
 
111
+ def determine_highlights(self, video_description: str, prompt_num: int = 1) -> str:
112
+ """Determine what constitutes highlights based on video description with different prompts."""
113
+ system_prompts = {
114
+ 1: "You are a highlight editor. List archetypal dramatic moments that would make compelling highlights if they appear in the video. Each moment should be specific enough to be recognizable but generic enough to potentially exist in any video of this type.",
115
+ 2: "You are a helpful visual-language assistant that can understand videos and edit. You are tasked helping the user to create highlight reels for videos. Generally, highlights should be relatively rare and important events in the video in question."
116
+ }
117
+ user_prompts = {
118
+ 1: "List potential highlight moments to look for in this video:",
119
+ 2: "List dramatic moments that would make compelling highlights if they appear in the video. Each moment should be specific enough to be recognizable but generic enough to potentially exist in any video of this type:"
120
+ }
121
+
122
  messages = [
123
  {
124
  "role": "system",
125
+ "content": [{"type": "text", "text": system_prompts[prompt_num]}]
126
  },
127
  {
128
  "role": "user",
129
+ "content": [{"type": "text", "text": f"""Here is a description of a video:\n\n{video_description}\n\n{user_prompts[prompt_num]}"""}]
130
  }
131
  ]
132
+
133
+ print(f"Using prompt {prompt_num} for highlight detection")
134
+ print(messages)
135
 
136
  inputs = self.processor.apply_chat_template(
137
  messages,
 
277
  gr.update(visible=False)
278
  ]
279
 
280
+ detector = VideoHighlightDetector(model_path=model_path, batch_size=16)
281
 
282
  yield [
283
  None,
 
299
  gr.update(visible=True)
300
  ]
301
 
302
+ highlights1 = detector.determine_highlights(video_desc, prompt_num=1)
303
+ highlights2 = detector.determine_highlights(video_desc, prompt_num=2)
304
+ formatted_highlights = f"### Highlight Criteria:\nSet 1:\n{highlights1}\n\nSet 2:\n{highlights2}"
305
 
306
  # Process video in segments
307
  segment_length = 10.0
308
+ kept_segments1 = []
309
+ kept_segments2 = []
310
+ segment_descriptions1 = []
311
+ segment_descriptions2 = []
312
  segments_processed = 0
313
  total_segments = int(duration / segment_length)
314
 
315
  for start_time in range(0, int(duration), int(segment_length)):
316
  end_time = min(start_time + segment_length, duration)
 
317
  progress = int((segments_processed / total_segments) * 100)
318
 
319
  yield [
 
338
  ]
339
  subprocess.run(cmd, check=True)
340
 
341
+ # Process with both highlight sets
342
+ if detector.process_segment(temp_segment.name, highlights1):
343
+ description = detector.analyze_segment(temp_segment.name)
344
+ kept_segments1.append((start_time, end_time))
345
+ segment_descriptions1.append(description)
346
+
347
+ if detector.process_segment(temp_segment.name, highlights2):
348
  description = detector.analyze_segment(temp_segment.name)
349
+ kept_segments2.append((start_time, end_time))
350
+ segment_descriptions2.append(description)
351
+
352
+ segments_processed += 1
353
+
354
+ # Calculate percentages of video kept for each highlight set
355
+ total_duration = duration
356
+ duration1 = sum(end - start for start, end in kept_segments1)
357
+ duration2 = sum(end - start for start, end in kept_segments2)
358
 
359
+ percent1 = (duration1 / total_duration) * 100
360
+ percent2 = (duration2 / total_duration) * 100
361
+
362
+ print(f"Highlight set 1: {percent1:.1f}% of video")
363
+ print(f"Highlight set 2: {percent2:.1f}% of video")
364
+
365
+ # Choose the set with lower percentage unless it's zero
366
+ if (0 < percent2 <= percent1 or percent1 == 0):
367
+ final_segments = kept_segments2
368
+ segment_descriptions = segment_descriptions2
369
+ selected_set = "2"
370
+ percent_used = percent2
371
+ else:
372
+ final_segments = kept_segments1
373
+ segment_descriptions = segment_descriptions1
374
+ selected_set = "1"
375
+ percent_used = percent1
376
+
377
+ if final_segments:
378
  # Create XSPF playlist
379
+ playlist_content = create_xspf_playlist(video, final_segments, segment_descriptions)
380
 
381
  # Save playlist to temporary file
382
  with tempfile.NamedTemporaryFile(mode='w', suffix='.xspf', delete=False) as f:
383
  f.write(playlist_content)
384
  playlist_path = f.name
385
 
386
+ completion_message = f"Processing complete! Using highlight set {selected_set} ({percent_used:.1f}% of video). You can download the playlist."
387
+
388
  yield [
389
  gr.update(value=playlist_path, visible=True),
390
+ completion_message,
391
  formatted_desc,
392
  formatted_highlights,
393
  gr.update(visible=True)