# Hugging Face Space — status: Running on A100
import os | |
import json | |
import gradio as gr | |
import tempfile | |
from PIL import Image, ImageDraw, ImageFont | |
import cv2 | |
from typing import Tuple, Optional | |
import torch | |
from pathlib import Path | |
import time | |
import torch | |
import spaces | |
import os | |
from video_highlight_detector import ( | |
load_model, | |
BatchedVideoHighlightDetector, | |
get_video_duration_seconds | |
) | |
def load_examples(json_path: str) -> dict:
    """Load the example-gallery specification from a JSON file.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 so titles/descriptions containing non-ASCII
    # characters load the same way regardless of the platform's locale.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)
def format_duration(seconds: int) -> str:
    """Format a duration in seconds as ``M:SS`` or ``H:MM:SS``.

    Args:
        seconds: Duration in seconds. Fractional values are truncated to whole
            seconds and negative values are treated as zero.

    Returns:
        ``"H:MM:SS"`` when the duration is at least one hour, otherwise
        ``"M:SS"`` (minutes are not zero-padded).
    """
    # The ``:02d`` format spec rejects floats, and floor division on negative
    # numbers produces nonsense such as "59:59", so normalise the input first.
    total = max(0, int(seconds))
    minutes, secs = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    if hours > 0:
        return f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{minutes}:{secs:02d}"
def create_ui(examples_path: str):
    """Build the Gradio app: an example gallery plus an upload-and-process panel.

    Args:
        examples_path: Path to a JSON file read with ``load_examples``. Its
            ``"examples"`` list supplies, per entry, ``title``, ``original``
            and ``highlights`` (each with ``url`` and ``duration_seconds``)
            and ``analysis`` (``video_description`` and ``highlight_types``).

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    max_duration_seconds = 1200  # 20 minutes
    preview_chars = 500

    examples_data = load_examples(examples_path)

    def _clip(text):
        # Keep the details panel readable: cut long model output and mark the cut.
        # NOTE(review): assumes the detector returns strings — confirm.
        return text[:preview_chars] + '...' if len(text) > preview_chars else text

    def _ui_state(message, description="", highlights="", video_path=None,
                  show_details=False):
        # One value per entry of the ``outputs`` list wired up below, in order:
        # status, video_description, highlight_types, output_video, accordion.
        # Visibility is changed through ``gr.update``; the output video starts
        # hidden, so it is revealed only when there is a file to show.
        return [
            message,
            description,
            highlights,
            gr.update(value=video_path, visible=video_path is not None),
            gr.update(visible=show_details),
        ]

    with gr.Blocks() as app:
        gr.Markdown("# Video Highlight Generator")
        gr.Markdown("Upload a video and get an automated highlight reel!")

        with gr.Row():
            gr.Markdown("## Example Results")

        with gr.Row():
            for example in examples_data["examples"]:
                with gr.Column():
                    gr.Video(
                        value=example["original"]["url"],
                        label=f"Original ({format_duration(example['original']['duration_seconds'])})",
                        interactive=False
                    )
                    gr.Markdown(f"### {example['title']}")

                with gr.Column():
                    gr.Video(
                        value=example["highlights"]["url"],
                        label=f"Highlights ({format_duration(example['highlights']['duration_seconds'])})",
                        interactive=False
                    )
                    with gr.Accordion("Model chain of thought details", open=False):
                        gr.Markdown(f"#Summary: {example['analysis']['video_description']}")
                        gr.Markdown(f"#Highlights to search for: {example['analysis']['highlight_types']}")

        gr.Markdown("## Try It Yourself!")
        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Upload your video (max 20 minutes)",
                    interactive=True
                )
                process_btn = gr.Button("Process Video", variant="primary")

            with gr.Column(scale=1):
                output_video = gr.Video(
                    label="Highlight Video",
                    visible=False,
                    interactive=False,
                )
                status = gr.Markdown()

                analysis_accordion = gr.Accordion(
                    "Model chain of thought details",
                    open=True,
                    visible=False
                )
                with analysis_accordion:
                    video_description = gr.Markdown("", elem_id="video_desc")
                    highlight_types = gr.Markdown("", elem_id="highlight_types")

        def process_with_updates(video):
            """Run the highlight pipeline, yielding a UI state after each step."""
            if not video:
                yield _ui_state("Please upload a video")
                return

            temp_output = None
            try:
                duration = get_video_duration_seconds(video)
                if duration > max_duration_seconds:
                    yield _ui_state("Video must be shorter than 20 minutes")
                    return

                # Each status is yielded immediately before the step it
                # describes, so the message on screen matches the work in progress.
                yield _ui_state("Loading model...", show_details=True)
                model, processor = load_model()
                detector = BatchedVideoHighlightDetector(model, processor, batch_size=8)

                yield _ui_state("Analyzing video content...", show_details=True)
                video_desc = detector.analyze_video_content(video)
                formatted_desc = f"#Summary: {_clip(video_desc)}"

                yield _ui_state(
                    "Determining highlight types...",
                    formatted_desc,
                    show_details=True,
                )
                highlights = detector.determine_highlights(video_desc)
                formatted_highlights = f"#Highlights to search for: {_clip(highlights)}"

                yield _ui_state(
                    "Detecting and extracting highlights...",
                    formatted_desc,
                    formatted_highlights,
                    show_details=True,
                )
                # Only the path is needed; close our handle before the detector
                # writes to it. ``delete=False`` keeps the file for Gradio to serve.
                with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
                    temp_output = tmp_file.name
                detector.create_highlight_video(video, temp_output)

                yield _ui_state(
                    "Processing complete!",
                    formatted_desc,
                    formatted_highlights,
                    video_path=temp_output,
                    show_details=True,
                )
            except Exception as e:
                # Top-level UI boundary: report the failure to the user instead
                # of crashing the event handler, and do not leak the temp file.
                if temp_output is not None:
                    try:
                        os.remove(temp_output)
                    except OSError:
                        pass
                yield _ui_state(f"Error processing video: {str(e)}")

        process_btn.click(
            process_with_updates,
            inputs=[input_video],
            outputs=[
                status,
                video_description,
                highlight_types,
                output_video,
                analysis_accordion
            ]
        )

    return app
# gr.Markdown("## Try It Yourself!") | |
# with gr.Row(): | |
# with gr.Column(scale=1): | |
# input_video = gr.Video( | |
# label="Upload your video (max 20 minutes)", | |
# interactive=True | |
# ) | |
# process_btn = gr.Button("Process Video", variant="primary") | |
# with gr.Column(scale=1): | |
# output_video = gr.Video( | |
# label="Highlight Video", | |
# visible=False, | |
# interactive=False, | |
# ) | |
# status = gr.Markdown() | |
# analysis_accordion = gr.Accordion( | |
# "Model chain of thought details", | |
# open=True, | |
# visible=False | |
# ) | |
# with analysis_accordion: | |
# video_description = gr.Markdown("", elem_id="video_desc") | |
# highlight_types = gr.Markdown("", elem_id="highlight_types") | |
# @spaces.GPU | |
# def on_process(video): | |
# if not video: | |
# return { | |
# status: "Please upload a video", | |
# video_description: "", | |
# highlight_types: "", | |
# output_video: gr.update(visible=False), | |
# analysis_accordion: gr.update(visible=False) | |
# } | |
# try: | |
# duration = get_video_duration_seconds(video) | |
# if duration > 1200: # 20 minutes | |
# return { | |
# status: "Video must be shorter than 20 minutes", | |
# video_description: "", | |
# highlight_types: "", | |
# output_video: gr.update(visible=False), | |
# analysis_accordion: gr.update(visible=False) | |
# } | |
# # Make accordion visible as soon as processing starts | |
# yield { | |
# status: "Loading model...", | |
# video_description: "", | |
# highlight_types: "", | |
# output_video: gr.update(visible=False), | |
# analysis_accordion: gr.update(visible=True) | |
# } | |
# model, processor = load_model() | |
# detector = BatchedVideoHighlightDetector(model, processor, batch_size=8) | |
# yield { | |
# status: "Analyzing video content...", | |
# video_description: "", | |
# highlight_types: "", | |
# output_video: gr.update(visible=False), | |
# analysis_accordion: gr.update(visible=True) | |
# } | |
# video_desc = detector.analyze_video_content(video) | |
# formatted_desc = f"#Summary: {video_desc[:500] + '...' if len(video_desc) > 500 else video_desc}" | |
# # Update description as soon as it's available | |
# yield { | |
# status: "Determining highlight types...", | |
# video_description: formatted_desc, | |
# highlight_types: "", | |
# output_video: gr.update(visible=False), | |
# analysis_accordion: gr.update(visible=True) | |
# } | |
# highlights = detector.determine_highlights(video_desc) | |
# formatted_highlights = f"#Highlights to search for: {highlights[:500] + '...' if len(highlights) > 500 else highlights}" | |
# # Update highlights as soon as they're available | |
# yield { | |
# status: "Detecting and extracting highlights...", | |
# video_description: formatted_desc, | |
# highlight_types: formatted_highlights, | |
# output_video: gr.update(visible=False), | |
# analysis_accordion: gr.update(visible=True) | |
# } | |
# with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file: | |
# temp_output = tmp_file.name | |
# detector.create_highlight_video(video, temp_output) | |
# return { | |
# status: "Processing complete!", | |
# video_description: formatted_desc, | |
# highlight_types: formatted_highlights, | |
# output_video: gr.update(value=temp_output, visible=True), | |
# analysis_accordion: gr.update(visible=True) | |
# } | |
# except Exception as e: | |
# return { | |
# status: f"Error processing video: {str(e)}", | |
# video_description: "", | |
# highlight_types: "", | |
# output_video: gr.update(visible=False), | |
# analysis_accordion: gr.update(visible=False) | |
# } | |
# process_btn.click( | |
# on_process, | |
# inputs=[input_video], | |
# outputs=[status, video_description, highlight_types, output_video, analysis_accordion] | |
# ) | |
# return app | |
if __name__ == "__main__":
    # Initialize CUDA: use the GPU when torch reports one, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # Place a one-element tensor on the chosen device before the UI is built.
    zero = torch.Tensor([0]).to(device)

    # Build the interface from the bundled example spec and start serving it.
    app = create_ui("video_spec.json")
    app.launch()