import json
import logging
import os
import subprocess
import tempfile

import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

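# Pipeline overview (three passes over the same vision-language model):
#   1. analyze_video_content  - describe what the whole video is about
#   2. determine_highlights   - decide what would count as a highlight
#   3. process_segment        - yes/no check on each 10-second chunk
# Chunks that pass step 3 are concatenated into the final highlight reel.
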
def load_examples(json_path: str) -> dict:
    """Load the example videos and their metadata from a JSON file."""
    with open(json_path, 'r') as f:
        return json.load(f)

def format_duration(seconds: int) -> str:
    """Format a duration in seconds as H:MM:SS, or M:SS when under an hour."""
    hours = seconds // 3600
    minutes = (seconds % 3600) // 60
    secs = seconds % 60
    if hours > 0:
        return f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{minutes}:{secs:02d}"

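# For example, format_duration(3725) returns "1:02:05" and
# format_duration(125) returns "2:05".
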
def get_video_duration_seconds(video_path: str) -> float:
    """Use ffprobe to get video duration in seconds."""
    cmd = [
        "ffprobe",
        "-v", "quiet",
        "-print_format", "json",
        "-show_format",
        video_path
    ]
    # check=True surfaces ffprobe failures directly instead of crashing
    # later on json.loads of empty output.
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    info = json.loads(result.stdout)
    return float(info["format"]["duration"])

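# ffprobe's -show_format output looks like {"format": {"duration": "93.5", ...}},
# i.e. the duration arrives as a string, hence the float() cast above.
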
class VideoHighlightDetector:
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        batch_size: int = 8
    ):
        self.device = device
        self.batch_size = batch_size

        # Load the processor and the vision-language model in bfloat16 with
        # FlashAttention 2 (requires the flash-attn package and a supported GPU).
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            _attn_implementation="flash_attention_2"
        ).to(device)

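    # The three methods below share one pattern: build a chat-formatted
    # prompt (with the video attached where needed), run generate(), and
    # decode only the newly generated tokens rather than the echoed prompt.
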
    def analyze_video_content(self, video_path: str) -> str:
        """Analyze video content to determine its type and description."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": "What type of video is this and what's happening in it? Be specific about the content type and general activities you observe."}
                ]
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        # Strip the prompt tokens so only the model's answer is returned.
        generated = outputs[0][inputs["input_ids"].shape[1]:]
        return self.processor.decode(generated, skip_special_tokens=True)

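    # Note that the follow-up pass below is text-only: it reasons over the
    # description produced above instead of re-watching the video.
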
    def determine_highlights(self, video_description: str) -> str:
        """Determine what constitutes highlights based on the video description."""
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a professional video editor specializing in creating viral highlight reels."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": f"""Based on this video description:

{video_description}

List the rare segments that should be included in a best-of-the-best highlight reel."""}]
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
        # Strip the prompt tokens so only the model's answer is returned.
        generated = outputs[0][inputs["input_ids"].shape[1]:]
        return self.processor.decode(generated, skip_special_tokens=True)

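    # The per-segment check is the inner loop of the pipeline, so it uses
    # greedy decoding (do_sample=False) and a small token budget for speed.
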
    def process_segment(self, video_path: str, highlight_types: str) -> bool:
        """Process a video segment and determine whether it contains highlights."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": f"""Do you see any of the following types of highlight moments in this video segment?

Potential highlights to look for:
{highlight_types}

Only answer yes if you see any of those moments and answer no if you don't."""}
                ]
            }
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.model.generate(**inputs, max_new_tokens=64, do_sample=False)
        # Decode only the generated tokens: the prompt itself contains the
        # word "yes", so decoding the full sequence would always match.
        generated = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.processor.decode(generated, skip_special_tokens=True).lower()

        return "yes" in response

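    # For scene_times == [(0, 10), (30, 40)], the filter graph built below
    # expands to (split across lines here for readability):
    #   [0:v]trim=start=0:end=10,setpts=PTS-STARTPTS[v0];
    #   [0:a]atrim=start=0:end=10,asetpts=PTS-STARTPTS[a0];
    #   [0:v]trim=start=30:end=40,setpts=PTS-STARTPTS[v1];
    #   [0:a]atrim=start=30:end=40,asetpts=PTS-STARTPTS[a1];
    #   [v0][a0][v1][a1]concat=n=2:v=1:a=1[outv][outa]
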
    def _concatenate_scenes(
        self,
        video_path: str,
        scene_times: list,
        output_path: str
    ):
        """Concatenate the selected scenes into the final video."""
        if not scene_times:
            logger.warning("No scenes to concatenate, skipping.")
            return

        filter_complex_parts = []
        concat_inputs = []
        for i, (start_sec, end_sec) in enumerate(scene_times):
            filter_complex_parts.append(
                f"[0:v]trim=start={start_sec}:end={end_sec},"
                f"setpts=PTS-STARTPTS[v{i}];"
            )
            filter_complex_parts.append(
                f"[0:a]atrim=start={start_sec}:end={end_sec},"
                f"asetpts=PTS-STARTPTS[a{i}];"
            )
            concat_inputs.append(f"[v{i}][a{i}]")

        concat_filter = f"{''.join(concat_inputs)}concat=n={len(scene_times)}:v=1:a=1[outv][outa]"
        filter_complex = "".join(filter_complex_parts) + concat_filter

        cmd = [
            "ffmpeg",
            "-y",
            "-i", video_path,
            "-filter_complex", filter_complex,
            "-map", "[outv]",
            "-map", "[outa]",
            "-c:v", "libx264",
            "-c:a", "aac",
            output_path
        ]

        logger.info(f"Running ffmpeg command: {' '.join(cmd)}")
        subprocess.run(cmd, check=True)

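# The UI shows a gallery of precomputed examples on top, then an upload panel
# whose outputs are streamed incrementally by the on_process generator below.
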
def create_ui(examples_path: str, model_path: str):
    examples_data = load_examples(examples_path)

    with gr.Blocks() as app:
        gr.Markdown("# Video Highlight Generator")
        gr.Markdown("Upload a video and get an automated highlight reel!")

        with gr.Row():
            gr.Markdown("## Example Results")

        with gr.Row():
            for example in examples_data["examples"]:
                with gr.Column():
                    gr.Video(
                        value=example["original"]["url"],
                        label=f"Original ({format_duration(example['original']['duration_seconds'])})",
                        interactive=False
                    )
                    gr.Markdown(f"### {example['title']}")

                with gr.Column():
                    gr.Video(
                        value=example["highlights"]["url"],
                        label=f"Highlights ({format_duration(example['highlights']['duration_seconds'])})",
                        interactive=False
                    )
                    with gr.Accordion("Chain of thought details", open=False):
                        gr.Markdown(f"### Summary:\n{example['analysis']['video_description']}")
                        gr.Markdown(f"### Highlights to search for:\n{example['analysis']['highlight_types']}")

        gr.Markdown("## Try It Yourself!")
        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Upload your video (max 30 minutes)",
                    interactive=True
                )
                process_btn = gr.Button("Process Video", variant="primary")

            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="Highlight Video",
                    visible=False,
                    interactive=False,
                )

                status = gr.Markdown()

                analysis_accordion = gr.Accordion(
                    "Chain of thought details",
                    open=True,
                    visible=False
                )

                with analysis_accordion:
                    video_description = gr.Markdown("", elem_id="video_desc")
                    highlight_types = gr.Markdown("", elem_id="highlight_types")

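        # on_process is a generator: every yield pushes a fresh state to the
        # five outputs wired up in process_btn.click() below (status text,
        # description, highlight types, output video, analysis accordion).
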
        @spaces.GPU
        def on_process(video):
            # Reset the UI before starting a new run.
            yield [
                "",
                "",
                "",
                gr.update(value=None, visible=False),
                gr.update(visible=False)
            ]

            if not video:
                yield [
                    "Please upload a video",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False)
                ]
                return

            try:
                duration = get_video_duration_seconds(video)
                if duration > 1800:
                    yield [
                        "Video must be shorter than 30 minutes",
                        "",
                        "",
                        gr.update(visible=False),
                        gr.update(visible=False)
                    ]
                    return

                yield [
                    "Initializing video highlight detector...",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False)
                ]

                detector = VideoHighlightDetector(
                    model_path=model_path,
                    batch_size=8
                )

                yield [
                    "Analyzing video content...",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=True)
                ]

                video_desc = detector.analyze_video_content(video)
                formatted_desc = f"### Summary:\n{video_desc[:500] + '...' if len(video_desc) > 500 else video_desc}"

                yield [
                    "Determining highlight types...",
                    formatted_desc,
                    "",
                    gr.update(visible=False),
                    gr.update(visible=True)
                ]

                highlights = detector.determine_highlights(video_desc)
                formatted_highlights = f"### Highlights to search for:\n{highlights[:500] + '...' if len(highlights) > 500 else highlights}"

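                # Segmentation pass: cut the video into fixed 10-second
                # chunks, classify each chunk independently, and keep the
                # time ranges of the chunks flagged as highlights.
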
                temp_dir = "temp_segments"
                os.makedirs(temp_dir, exist_ok=True)

                segment_length = 10.0
                kept_segments = []
                segments_processed = 0
                # Count the loop iterations exactly; int(duration / segment_length)
                # would be zero (and divide by zero below) for clips shorter
                # than one segment.
                total_segments = len(range(0, int(duration), int(segment_length)))

                for start_time in range(0, int(duration), int(segment_length)):
                    segments_processed += 1
                    progress = int((segments_processed / total_segments) * 100)

                    yield [
                        f"Processing segments... {progress}% complete",
                        formatted_desc,
                        formatted_highlights,
                        gr.update(visible=False),
                        gr.update(visible=True)
                    ]

                    segment_path = f"{temp_dir}/segment_{start_time}.mp4"
                    end_time = min(start_time + segment_length, duration)

                    # Stream-copy the chunk (fast, no re-encode; cut points
                    # snap to the nearest keyframes).
                    cmd = [
                        "ffmpeg",
                        "-y",
                        "-i", video,
                        "-ss", str(start_time),
                        "-t", str(segment_length),
                        "-c", "copy",
                        segment_path
                    ]
                    subprocess.run(cmd, check=True)

                    if detector.process_segment(segment_path, highlights):
                        kept_segments.append((start_time, end_time))

                    os.remove(segment_path)

                os.rmdir(temp_dir)

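                # Stitch the surviving segments into the final reel. The temp
                # file uses delete=False so it outlives the context manager
                # and Gradio can still serve it afterwards.
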
                if kept_segments:
                    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
                        temp_output = tmp_file.name
                    detector._concatenate_scenes(video, kept_segments, temp_output)

                    yield [
                        "Processing complete!",
                        formatted_desc,
                        formatted_highlights,
                        gr.update(value=temp_output, visible=True),
                        gr.update(visible=True)
                    ]
                else:
                    yield [
                        "No highlights detected in the video.",
                        formatted_desc,
                        formatted_highlights,
                        gr.update(visible=False),
                        gr.update(visible=True)
                    ]

            except Exception as e:
                logger.exception("Error processing video")
                yield [
                    f"Error processing video: {str(e)}",
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False)
                ]
            finally:
                # Release cached GPU memory between runs.
                torch.cuda.empty_cache()

        process_btn.click(
            on_process,
            inputs=[input_video],
            outputs=[
                status,
                video_description,
                highlight_types,
                output_video,
                analysis_accordion
            ],
            queue=True,
        )

    return app

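# Entry point. flash-attn is installed at runtime (a common Hugging Face
# Spaces workaround) because _attn_implementation="flash_attention_2" above
# needs it at model-load time; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells the
# flash-attn build to skip compiling the CUDA extension.
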
if __name__ == "__main__":
    # Merge os.environ in rather than replacing the environment wholesale,
    # which would drop PATH and break the pip invocation.
    subprocess.run(
        'pip install flash-attn --no-build-isolation',
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
        shell=True
    )

    app = create_ui("video_spec.json", "HuggingFaceTB/SmolVLM2-2.2B-Instruct")
    app.launch()