|
import os |
|
import json |
|
import gradio as gr |
|
import tempfile |
|
from PIL import Image, ImageDraw, ImageFont |
|
import cv2 |
|
from typing import Tuple, Optional |
|
import torch |
|
import spaces |
|
from pathlib import Path |
|
import time |
|
|
|
|
|
from video_highlight_detector import ( |
|
load_model, |
|
BatchedVideoHighlightDetector, |
|
get_video_duration_seconds |
|
) |
|
|
|
def load_examples(json_path: str) -> dict:
    """Load the pre-computed examples spec from a JSON file.

    Args:
        json_path: Path to the JSON file.

    Returns:
        The decoded JSON document.
    """
    # JSON text is UTF-8 by spec; being explicit avoids depending on the
    # platform locale (example titles/descriptions may contain non-ASCII).
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
|
def format_duration(seconds: float) -> str:
    """Convert a duration in seconds to ``M:SS`` or ``H:MM:SS``.

    Fractional seconds are truncated. Integer input behaves as before;
    float input is accepted because durations read from JSON may be floats.

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        ``"H:MM:SS"`` when the duration is at least one hour, else ``"M:SS"``.
    """
    # ``:02d`` rejects floats, so normalise to a whole number of seconds first.
    total = int(seconds)
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    if hours > 0:
        return f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{minutes}:{secs:02d}"
|
|
|
def add_watermark(video_path: str, output_path: str):
    """Burn a text watermark into the bottom-right corner of a video via ffmpeg.

    The text is drawn in white on a 50%-opaque black box, 10 px from the
    right and bottom edges. The audio stream is copied unchanged.

    Args:
        video_path: Path of the source video.
        output_path: Path the watermarked video is written to (overwritten
            if it already exists).

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    import subprocess

    watermark_text = "🤗 SmolVLM2 Highlight"
    drawtext_filter = (
        f"drawtext=text='{watermark_text}':fontcolor=white:fontsize=24:"
        "box=1:boxcolor=black@0.5:boxborderw=5:x=w-tw-10:y=h-th-10"
    )
    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters are handled safely and cannot inject commands.
    command = [
        "ffmpeg",
        "-y",  # overwrite instead of prompting, which would block a server
        "-i", video_path,
        "-vf", drawtext_filter,
        "-codec:a", "copy",
        output_path,
    ]
    # check=True surfaces ffmpeg failures instead of silently returning a
    # path to a file that was never written; stdin is closed for the same
    # reason as ``-y`` (ffmpeg must never wait on interactive input).
    subprocess.run(command, check=True, stdin=subprocess.DEVNULL)
|
|
|
def _truncate(text: str, limit: int = 500) -> str:
    """Return *text* cut to *limit* characters plus "..." when it is longer."""
    return text[:limit] + "..." if len(text) > limit else text


def process_video(
    video_path: str,
    progress=gr.Progress()
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Run the highlight pipeline on a video and watermark the result.

    Args:
        video_path: Path of the uploaded video.
        progress: Gradio progress tracker used to report pipeline stages.

    Returns:
        A 4-tuple ``(output_path, video_description, highlight_types, error)``.
        On success the first three are set (the two texts truncated to 500
        characters) and ``error`` is ``None``. On failure the first three are
        ``None`` and ``error`` is a human-readable message.
    """
    temp_output = None
    try:
        duration = get_video_duration_seconds(video_path)
        if duration > 1200:
            return None, None, None, "Video must be shorter than 20 minutes"

        progress(0.1, desc="Loading model...")
        model, processor = load_model()
        detector = BatchedVideoHighlightDetector(model, processor)

        progress(0.2, desc="Analyzing video content...")
        video_description = detector.analyze_video_content(video_path)

        progress(0.3, desc="Determining highlight types...")
        highlight_types = detector.determine_highlights(video_description)

        progress(0.4, desc="Detecting and extracting highlights...")
        # Only reserve a unique file name here; the detector writes into it.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
            temp_output = tmp_file.name
        detector.create_highlight_video(video_path, temp_output)

        progress(0.9, desc="Adding watermark...")
        # Derive the name from the file stem only, so directories whose names
        # contain ".mp4" are never rewritten.
        temp_path = Path(temp_output)
        output_path = str(temp_path.with_name(f"{temp_path.stem}_watermark.mp4"))
        add_watermark(temp_output, output_path)

        return output_path, _truncate(video_description), _truncate(highlight_types), None

    except Exception as e:
        return None, None, None, f"Error processing video: {e}"

    finally:
        # Remove the intermediate (un-watermarked) video on every path,
        # including failures in highlight creation or watermarking.
        if temp_output is not None:
            try:
                os.unlink(temp_output)
            except OSError:
                # Best-effort cleanup; never mask the real result or error.
                pass
|
|
|
|
|
def create_ui(examples_path: str):
    """Build the Gradio Blocks interface for the highlight generator.

    The page shows pre-computed example results loaded from ``examples_path``,
    followed by an upload widget that runs ``process_video`` on the uploaded
    file and reveals the results.

    Args:
        examples_path: Path to a JSON file with a top-level ``"examples"``
            list. Each entry is read for ``title``; ``original`` and
            ``highlights`` (``url``, ``duration_seconds``, optional
            ``thumbnail_url``); and ``analysis`` (``video_description``,
            ``highlight_types``).

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    examples_data = load_examples(examples_path)

    with gr.Blocks() as app:
        gr.Markdown("# Video Highlight Generator")
        gr.Markdown("Upload a video (max 20 minutes) and get an automated highlight reel!")

        with gr.Row():
            gr.Markdown("## Example Results")

        for example in examples_data["examples"]:
            with gr.Row():
                with gr.Column():
                    gr.Video(
                        example["original"]["url"],
                        label=f"Original ({format_duration(example['original']['duration_seconds'])})",
                        thumbnail=example["original"].get("thumbnail_url", None)
                    )
                    gr.Markdown(example["title"])

                with gr.Column():
                    gr.Video(
                        example["highlights"]["url"],
                        label=f"Highlights ({format_duration(example['highlights']['duration_seconds'])})",
                        thumbnail=example["highlights"].get("thumbnail_url", None)
                    )
                    with gr.Accordion("Analysis", open=False):
                        gr.Markdown(example["analysis"]["video_description"])
                        gr.Markdown(example["analysis"]["highlight_types"])

        gr.Markdown("## Try It Yourself!")
        with gr.Row():
            input_video = gr.Video(
                label="Upload your video (max 20 minutes)",
                source="upload"
            )

        with gr.Row(visible=False) as results_row:
            with gr.Column():
                video_description = gr.Markdown(label="Video Analysis")
            with gr.Column():
                highlight_types = gr.Markdown(label="Detected Highlights")

        with gr.Row(visible=False) as output_row:
            output_video = gr.Video(label="Highlight Video")
            download_btn = gr.Button("Download Highlights")

        error_msg = gr.Markdown(visible=False)

        def _visibility(flag):
            # A fresh update dict per output; Gradio may consume update dicts.
            return gr.update(visible=flag)

        def on_upload(video):
            # Gradio only applies changes the handler *returns*; assigning to
            # component attributes at event time does not update the page.
            # Visibility and messages therefore go through gr.update(...),
            # and both rows are listed as outputs of the event below.
            if not video:
                return (
                    None, None, None,
                    gr.update(value="Please upload a video", visible=True),
                    _visibility(False), _visibility(False),
                )

            output_path, desc, highlights, err = process_video(video)

            if err:
                return (
                    None, None, None,
                    gr.update(value=err, visible=True),
                    _visibility(False), _visibility(False),
                )

            return (
                output_path, desc, highlights,
                gr.update(value="", visible=False),
                _visibility(True), _visibility(True),
            )

        input_video.change(
            on_upload,
            inputs=[input_video],
            outputs=[
                output_video,
                video_description,
                highlight_types,
                error_msg,
                results_row,
                output_row,
            ]
        )

        # NOTE(review): this handler echoes the current video back into the
        # same component; it does not trigger a file download by itself.
        download_btn.click(
            lambda x: x,
            inputs=[output_video],
            outputs=[output_video]
        )

    return app
|
|
|
if __name__ == "__main__":
    # Build the UI from the bundled example spec, then start the Gradio server.
    spec_file = "video_spec.json"
    app = create_ui(spec_file)
    app.launch()
|
|