mfarre's picture
mfarre HF staff
.
946878e
raw
history blame
11 kB
import os
import json
import gradio as gr
import tempfile
from PIL import Image, ImageDraw, ImageFont
import cv2
from typing import Tuple, Optional
import torch
from pathlib import Path
import time
import torch
import spaces
import os
from video_highlight_detector import (
load_model,
BatchedVideoHighlightDetector,
get_video_duration_seconds
)
def load_examples(json_path: str) -> dict:
    """Load the example-gallery specification from a JSON file.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The parsed JSON document. ``create_ui`` indexes it with the
        ``"examples"`` key, so a JSON object is expected at the top level.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # default locale encoding when titles/descriptions contain non-ASCII text.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def format_duration(seconds: int) -> str:
    """Format a duration in seconds as ``M:SS`` or ``H:MM:SS``.

    Args:
        seconds: Duration in seconds. Fractional values are truncated to
            whole seconds and negative values are treated as zero.

    Returns:
        ``"H:MM:SS"`` when the duration is at least one hour, otherwise
        ``"M:SS"`` (minutes are not zero-padded).
    """
    # The ``:02d`` format code rejects floats, so normalise to a
    # non-negative int before splitting into fields.
    total = max(int(seconds), 0)
    minutes, secs = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    if hours > 0:
        return f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{minutes}:{secs:02d}"
# @spaces.GPU
# def process_video(
# video_path: str,
# progress = gr.Progress()
# ) -> Tuple[str, str, str, str]:
# try:
# # duration = get_video_duration_seconds(video_path)
# # if duration > 1200: # 20 minutes
# # return None, None, None, "Video must be shorter than 20 minutes"
# progress(0.1, desc="Loading model...")
# model, processor = load_model()
# detector = BatchedVideoHighlightDetector(model, processor, batch_size=8)
# progress(0.2, desc="Analyzing video content...")
# video_description = detector.analyze_video_content(video_path)
# progress(0.3, desc="Determining highlight types...")
# highlight_types = detector.determine_highlights(video_description)
# progress(0.4, desc="Detecting and extracting highlights...")
# with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
# output_path = tmp_file.name
# detector.create_highlight_video(video_path, output_path)
# # progress(0.9, desc="Adding watermark...")
# # output_path = temp_output.replace('.mp4', '_watermark.mp4')
# # add_watermark(temp_output, output_path)
# os.unlink(output_path)
# progress(1.0, desc="Complete!")
# video_description = video_description[:500] + "..." if len(video_description) > 500 else video_description
# highlight_types = highlight_types[:500] + "..." if len(highlight_types) > 500 else highlight_types
# return output_path, video_description, highlight_types, None
# except Exception as e:
# return None, None, None, f"Error processing video: {str(e)}"
# Longest upload accepted by the "Try It Yourself" flow (20 minutes).
_MAX_VIDEO_SECONDS = 1200
# Longest analysis text shown in the UI before it is cut with "...".
_MAX_ANALYSIS_CHARS = 500


def _clip_text(text: str, limit: int = _MAX_ANALYSIS_CHARS) -> str:
    """Return *text* cut to *limit* characters, adding '...' if it was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def create_ui(examples_path: str):
    """Build the Gradio interface for the highlight generator.

    The page has two parts: a gallery of pre-computed examples read from
    ``examples_path`` (original video, highlight video and the model's
    analysis text), and an upload form that runs the highlight detector on
    a user-supplied video.

    Args:
        examples_path: Path of the JSON file describing the examples. Each
            entry of its ``"examples"`` list is read for ``title``,
            ``original``/``highlights`` (``url`` and ``duration_seconds``)
            and ``analysis`` (``video_description``, ``highlight_types``).

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    examples_data = load_examples(examples_path)

    with gr.Blocks() as app:
        gr.Markdown("# Video Highlight Generator")
        gr.Markdown("Upload a video and get an automated highlight reel!")

        with gr.Row():
            gr.Markdown("## Example Results")

        # One pair of columns per example: original on the left,
        # highlights plus the model's analysis on the right.
        with gr.Row():
            for example in examples_data["examples"]:
                with gr.Column():
                    gr.Video(
                        value=example["original"]["url"],
                        label=f"Original ({format_duration(example['original']['duration_seconds'])})",
                        interactive=False
                    )
                    gr.Markdown(f"### {example['title']}")

                with gr.Column():
                    gr.Video(
                        value=example["highlights"]["url"],
                        label=f"Highlights ({format_duration(example['highlights']['duration_seconds'])})",
                        interactive=False
                    )
                    with gr.Accordion("Model chain of thought details", open=False):
                        gr.Markdown(f"#Summary: {example['analysis']['video_description']}")
                        gr.Markdown(f"#Highlights to search for: {example['analysis']['highlight_types']}")

        # Main interface section
        gr.Markdown("## Try It Yourself!")
        with gr.Row():
            # Left column: Upload and Process
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Upload your video (max 20 minutes)",
                    interactive=True
                )
                process_btn = gr.Button("Process Video", variant="primary")

            # Right column: Progress and Analysis
            with gr.Column(scale=1):
                # Output video (initially hidden)
                # NOTE(review): `downloadable` is kept from the original code;
                # confirm the installed Gradio version accepts this keyword.
                output_video = gr.Video(
                    label="Highlight Video",
                    visible=False,
                    interactive=False,
                    downloadable=True
                )
                status = gr.Markdown()
                # Hidden until a run succeeds; on_process toggles it.
                with gr.Accordion("Model chain of thought details", open=True, visible=False) as analysis_accordion:
                    video_description = gr.Markdown("", elem_id="video_desc")
                    highlight_types = gr.Markdown("", elem_id="highlight_types")

        def build_result(message, description="", highlights="", video_path=None):
            """Map one outcome onto every output component of the click event.

            A result without ``video_path`` is a failure/validation outcome:
            the output video and the analysis accordion are both hidden.
            """
            succeeded = video_path is not None
            if succeeded:
                video_update = gr.update(value=video_path, visible=True)
            else:
                video_update = gr.update(visible=False)
            return {
                status: message,
                video_description: description,
                highlight_types: highlights,
                output_video: video_update,
                # The accordion wraps the two Markdown outputs, so it has to
                # be shown for the analysis text to be visible at all.
                analysis_accordion: gr.update(visible=succeeded),
            }

        # presumably requests GPU hardware for the duration of each call when
        # running on a Hugging Face ZeroGPU Space - see the `spaces` docs.
        @spaces.GPU
        def on_process(video, progress=gr.Progress()):
            """Run highlight detection on the uploaded video.

            Returns a dict keyed by output component. Intermediate stages are
            reported through ``progress`` only; results are kept in local
            variables rather than on the shared component objects, so
            concurrent requests cannot see each other's text.
            """
            if not video:
                return build_result("Please upload a video")

            # UI boundary: any failure is reported in the status area
            # instead of surfacing as an unhandled error.
            try:
                duration = get_video_duration_seconds(video)
                if duration > _MAX_VIDEO_SECONDS:
                    return build_result("Video must be shorter than 20 minutes")

                progress(0.1, desc="Loading model...")
                model, processor = load_model()
                detector = BatchedVideoHighlightDetector(model, processor, batch_size=8)

                progress(0.2, desc="Analyzing video content...")
                video_desc = detector.analyze_video_content(video)
                description_md = f"#Summary: {_clip_text(video_desc)}"

                progress(0.3, desc="Determining highlight types...")
                highlights = detector.determine_highlights(video_desc)
                highlights_md = f"#Highlights to search for: {_clip_text(highlights)}"

                progress(0.4, desc="Detecting and extracting highlights...")
                # Only the unique path is needed; close the handle before the
                # detector writes to it. The file is intentionally left on
                # disk because Gradio serves the result from this path.
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
                    temp_output = tmp_file.name
                detector.create_highlight_video(video, temp_output)

                progress(1.0, desc="Complete!")
                return build_result(
                    "Processing complete!",
                    description=description_md,
                    highlights=highlights_md,
                    video_path=temp_output,
                )
            except Exception as e:
                return build_result(f"Error processing video: {str(e)}")

        process_btn.click(
            on_process,
            inputs=[input_video],
            outputs=[
                status,
                video_description,
                highlight_types,
                output_video,
                analysis_accordion,
            ]
        )

    return app
# gr.Markdown("## Try It Yourself!")
# with gr.Row():
# input_video = gr.Video(
# label="Upload your video (max 20 minutes)",
# interactive=True
# )
# gr.Progress()
# process_btn = gr.Button("Process Video", variant="primary")
# status = gr.Markdown(visible=True)
# with gr.Row() as results_row:
# with gr.Column():
# video_description = gr.Markdown(visible=False)
# with gr.Column():
# highlight_types = gr.Markdown(visible=False)
# with gr.Row() as output_row:
# output_video = gr.Video(label="Highlight Video", visible=False)
# download_btn = gr.Button("Download Highlights", visible=False)
# def on_process(video, progress=gr.Progress()):
# if not video:
# return {
# status: "Please upload a video",
# video_description: gr.update(visible=False),
# highlight_types: gr.update(visible=False),
# output_video: gr.update(visible=False),
# download_btn: gr.update(visible=False)
# }
# status.value = "Processing video..."
# output_path, desc, highlights, err = process_video(video, progress=progress)
# if err:
# return {
# status: f"Error: {err}",
# video_description: gr.update(visible=False),
# highlight_types: gr.update(visible=False),
# output_video: gr.update(visible=False),
# download_btn: gr.update(visible=False)
# }
# return {
# status: "Processing complete!",
# video_description: gr.update(value=desc, visible=True),
# highlight_types: gr.update(value=highlights, visible=True),
# output_video: gr.update(value=output_path, visible=True),
# download_btn: gr.update(visible=True)
# }
# process_btn.click(
# on_process,
# inputs=[input_video],
# outputs=[status, video_description, highlight_types, output_video, download_btn]
# )
# download_btn.click(
# lambda x: x,
# inputs=[output_video],
# outputs=[output_video]
# )
# return app
if __name__ == "__main__":
    # Touch the device once at startup: allocate a one-element tensor on
    # CUDA when it is available, otherwise on the CPU.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    zero = torch.Tensor([0]).to(device)

    # Build the interface from the bundled example spec and start serving.
    app = create_ui("video_spec.json")
    app.launch()