# Hugging Face Space status: Running on A100
import json
import logging
import os
import re
import subprocess
import tempfile
import xml.etree.ElementTree as ET
from pathlib import Path
from urllib.parse import quote
from xml.dom import minidom

import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
# Module logger; records reach the root handler configured below.
logger = logging.getLogger(__name__)
# Root logging at INFO so progress and errors from this app are visible.
logging.basicConfig(level="INFO")
def load_examples(json_path: str) -> dict:
    """Read and parse the JSON file at *json_path*.

    The file is decoded explicitly as UTF-8 (the encoding JSON text uses)
    instead of the platform's locale default, so non-ASCII content loads
    the same way on every host.

    Args:
        json_path: path to a JSON document.

    Returns:
        The parsed JSON value.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)
def format_duration(seconds: float) -> str:
    """Render *seconds* as a zero-padded ``HH:MM:SS`` clock string.

    Fractional seconds are truncated; hours are not wrapped at 24.
    """
    fields = (
        seconds // 3600,         # whole hours
        (seconds % 3600) // 60,  # whole minutes within the hour
        seconds % 60,            # seconds within the minute
    )
    return ":".join(f"{int(value):02d}" for value in fields)
def get_video_duration_seconds(video_path: str) -> float:
    """Return the container duration of *video_path* in seconds, using ffprobe.

    Args:
        video_path: path of the media file to probe.

    Returns:
        The ``format.duration`` value reported by ffprobe, as a float.

    Raises:
        subprocess.CalledProcessError: if ffprobe exits with a non-zero status
            (for example, a missing or unreadable file).
        ValueError: if ffprobe's report contains no format duration.
    """
    cmd = [
        "ffprobe",
        "-v", "quiet",
        "-print_format", "json",
        "-show_format",
        video_path,
    ]
    # check=True reports an ffprobe failure as such, instead of letting it
    # surface later as an opaque JSON decode error or KeyError('format').
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    info = json.loads(result.stdout or "{}")
    duration = info.get("format", {}).get("duration")
    if duration is None:
        raise ValueError(f"ffprobe reported no duration for {video_path!r}")
    return float(duration)
class VideoHighlightDetector:
    """Wrap a chat-style vision-language model for video highlight detection.

    Every public method builds a chat prompt, runs ``model.generate`` and
    works with only the text the model generated after that prompt.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        batch_size: int = 8
    ):
        """Load the processor and the bfloat16 model from *model_path*.

        Args:
            model_path: hub id or local path accepted by ``from_pretrained``.
            device: torch device the model and every prompt are moved to.
            batch_size: stored on the instance; no method here reads it.
        """
        self.device = device
        self.batch_size = batch_size
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16
        ).to(device)

    @staticmethod
    def _answer_is_yes(response: str) -> bool:
        """Return True when the first standalone yes/no word in *response* is "yes".

        A plain substring test would wrongly accept replies such as
        "No, I only see eyes".
        """
        match = re.search(r"\b(yes|no)\b", response.lower())
        return match is not None and match.group(1) == "yes"

    def _generate(self, messages: list, **generate_kwargs) -> str:
        """Run one chat turn and return only the newly generated text, stripped.

        Args:
            messages: chat messages in the processor's chat-template format.
            **generate_kwargs: forwarded to ``model.generate``.
        """
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.device, dtype=torch.bfloat16)  # match the model's bf16 weights
        outputs = self.model.generate(**inputs, **generate_kwargs)
        # outputs[0] begins with the prompt tokens; decode only what follows so
        # the reply is recovered without searching for an "Assistant:" marker.
        prompt_length = inputs["input_ids"].shape[1]
        return self.processor.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

    def analyze_video_content(self, video_path: str) -> str:
        """Describe the type of video and what happens in it (sampled, temperature 0.7)."""
        system_message = "You are a helpful assistant that can understand videos. Describe what type of video this is and what's happening in it."
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": "What type of video is this and what's happening in it? Be specific about the content type and general activities you observe."}
                ]
            }
        ]
        return self._generate(messages, max_new_tokens=512, do_sample=True, temperature=0.7)

    def analyze_segment(self, video_path: str) -> str:
        """Analyze a specific video segment and provide a brief description."""
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "Describe what is happening in this specific video segment in a brief, concise way."}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": "What is happening in this segment? Provide a very brief and concise description."}
                ]
            }
        ]
        return self._generate(messages, max_new_tokens=128, do_sample=True, temperature=0.7)

    def determine_highlights(self, video_description: str) -> str:
        """Ask the model (text only) which kinds of moments belong in the highlights."""
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a professional video editor specializing in creating viral highlight reels."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": f"Based on this description, list which segments should be included in highlights: {video_description}"}]
            }
        ]
        return self._generate(messages, max_new_tokens=256, do_sample=True, temperature=0.7)

    def process_segment(self, video_path: str, highlight_types: str) -> bool:
        """Return True if the model answers "yes" to seeing *highlight_types* in the clip.

        Decoding is greedy (``do_sample=False``) so the verdict is deterministic.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": f"Do you see any of these elements in the video: {highlight_types}? Answer yes or no."}
                ]
            }
        ]
        response = self._generate(messages, max_new_tokens=64, do_sample=False)
        return self._answer_is_yes(response)
def create_xspf_playlist(video_path: str, segments: list, descriptions: list) -> str:
    """Create an XSPF playlist with one track per highlight segment.

    Every track points at the same source video and carries VLC
    ``start-time``/``stop-time`` options so that VLC plays only that segment.
    Human-readable ``HH:MM:SS`` start/end values are also kept as ``<meta>``.

    Args:
        video_path: path of the source video. Only its basename is written,
            percent-encoded, as a URI relative to the playlist file, so the
            playlist works when saved next to the video.
        segments: ``(start_seconds, end_seconds)`` pairs.
        descriptions: one text description per segment.

    Returns:
        The pretty-printed playlist XML.

    Raises:
        ValueError: if ``segments`` and ``descriptions`` differ in length
            (previously the extra entries were silently dropped).
    """
    if len(segments) != len(descriptions):
        raise ValueError(
            f"Got {len(segments)} segments but {len(descriptions)} descriptions"
        )

    xspf_ns = "http://xspf.org/ns/0/"
    vlc_ns = "http://www.videolan.org/vlc/playlist/ns/0/"
    vlc_application = "http://www.videolan.org/vlc/playlist/0"

    # Namespaces are declared as plain attributes and VLC elements use a
    # literal "vlc:" prefix, so ElementTree writes them verbatim and no
    # generated "ns0" prefix has to be patched in the output afterwards.
    root = ET.Element("playlist", {
        "xmlns": xspf_ns,
        "xmlns:vlc": vlc_ns,
        "version": "1",
    })

    video_filename = os.path.basename(video_path)
    ET.SubElement(root, "title").text = f"{video_filename} - Highlights"

    # Relative, percent-encoded URI. The old "file:///<name>" form pointed at
    # the filesystem root and broke on spaces or '#' in file names.
    location_uri = quote(video_filename)

    track_list = ET.SubElement(root, "trackList")
    # Playlist-level VLC extension; each vlc:item refers to a track's vlc:id.
    playlist_extension = ET.Element("extension", {"application": vlc_application})

    for idx, ((start_time, end_time), description) in enumerate(zip(segments, descriptions)):
        track = ET.SubElement(track_list, "track")
        ET.SubElement(track, "location").text = location_uri
        ET.SubElement(track, "title").text = f"Highlight {idx + 1}: {description}"
        ET.SubElement(track, "annotation").text = description

        ET.SubElement(track, "meta", {"rel": "start"}).text = format_duration(start_time)
        ET.SubElement(track, "meta", {"rel": "end"}).text = format_duration(end_time)

        # VLC seeks using these per-track options (values in seconds).
        track_extension = ET.SubElement(track, "extension", {"application": vlc_application})
        ET.SubElement(track_extension, "vlc:id").text = str(idx)
        ET.SubElement(track_extension, "vlc:option").text = f"start-time={float(start_time):.3f}"
        ET.SubElement(track_extension, "vlc:option").text = f"stop-time={float(end_time):.3f}"

        ET.SubElement(playlist_extension, "vlc:item", {"tid": str(idx)})

    root.append(playlist_extension)

    return minidom.parseString(ET.tostring(root, encoding="unicode")).toprettyxml(indent="  ")
def create_ui(examples_path: str, model_path: str):
    """Build the Gradio app that turns an uploaded video into an XSPF highlight playlist.

    Args:
        examples_path: JSON file read via ``load_examples`` while the UI is
            built; a missing or malformed file fails here.
        model_path: model id or path passed to ``VideoHighlightDetector``.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for launching it.
    """
    # NOTE(review): the examples are parsed but never wired into the UI.
    load_examples(examples_path)

    max_duration_seconds = 1800  # 30 minutes, matching the upload label
    segment_length = 10.0        # seconds of video judged per model call

    cached_detector = None

    def get_detector():
        """Load the model on the first request and reuse it afterwards.

        Previously a new detector (and a full model load) was created on
        every click.
        """
        nonlocal cached_detector
        if cached_detector is None:
            cached_detector = VideoHighlightDetector(model_path=model_path)
        return cached_detector

    def extract_segment(video, start_time, out_path):
        """Re-encode ``segment_length`` seconds of *video* from *start_time* into *out_path*."""
        cmd = [
            "ffmpeg",
            "-y",
            # -ss before -i seeks the input instead of decoding and discarding
            # everything up to start_time. That made each later segment slower
            # than the last. Because the clip is re-encoded, the cut stays accurate.
            "-ss", str(start_time),
            "-i", video,
            "-t", str(segment_length),
            "-c:v", "libx264",
            "-preset", "ultrafast",
            out_path,
        ]
        subprocess.run(cmd, check=True)

    def failure(message):
        """Outputs that show *message* and hide the playlist and analysis panels."""
        return [
            gr.update(value=None, visible=False),
            message,
            "",
            "",
            gr.update(visible=False),
        ]

    with gr.Blocks() as app:
        gr.Markdown("# Video Highlight Playlist Generator")
        gr.Markdown("Upload a video and get an XSPF playlist of highlights!")
        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Upload your video (max 30 minutes)",
                    interactive=True
                )
                process_btn = gr.Button("Process Video", variant="primary")
            with gr.Column(scale=1):
                output_playlist = gr.File(
                    label="Highlight Playlist (XSPF)",
                    visible=False,
                    interactive=False,
                )
                status = gr.Markdown()
                analysis_accordion = gr.Accordion(
                    "Analysis Details",
                    open=True,
                    visible=False
                )
                with analysis_accordion:
                    video_description = gr.Markdown("")
                    highlight_types = gr.Markdown("")

        def on_process(video):
            """Run the full highlight pipeline for one uploaded *video* path."""
            if not video:
                return failure("Please upload a video")
            try:
                duration = get_video_duration_seconds(video)
                if duration > max_duration_seconds:
                    return failure("Video must be shorter than 30 minutes")

                detector = get_detector()

                # Describe the whole video, then derive the highlight criteria.
                video_desc = detector.analyze_video_content(video)
                formatted_desc = f"### Video Summary:\n{video_desc}"
                highlights = detector.determine_highlights(video_desc)
                formatted_highlights = f"### Highlight Criteria:\n{highlights}"

                # Judge the video in fixed-length windows; keep matching ones.
                kept_segments = []
                segment_descriptions = []
                for start_time in range(0, int(duration), int(segment_length)):
                    end_time = min(start_time + segment_length, duration)
                    with tempfile.NamedTemporaryFile(suffix='.mp4') as temp_segment:
                        extract_segment(video, start_time, temp_segment.name)
                        if detector.process_segment(temp_segment.name, highlights):
                            description = detector.analyze_segment(temp_segment.name)
                            kept_segments.append((start_time, end_time))
                            segment_descriptions.append(description)

                if not kept_segments:
                    return [
                        gr.update(value=None, visible=False),
                        "No highlights detected in the video.",
                        formatted_desc,
                        formatted_highlights,
                        gr.update(visible=True)
                    ]

                playlist_content = create_xspf_playlist(video, kept_segments, segment_descriptions)
                # The XML has no encoding declaration, so it must be UTF-8;
                # descriptions may contain non-ASCII text. delete=False keeps
                # the file on disk for Gradio to serve after we return.
                with tempfile.NamedTemporaryFile(
                    mode='w', suffix='.xspf', delete=False, encoding='utf-8'
                ) as f:
                    f.write(playlist_content)
                    playlist_path = f.name
                return [
                    gr.update(value=playlist_path, visible=True),
                    "Processing complete! Download the XSPF playlist.",
                    formatted_desc,
                    formatted_highlights,
                    gr.update(visible=True)
                ]
            except Exception as e:
                # UI boundary: log the traceback and show the message to the user.
                logger.exception("Error processing video")
                return failure(f"Error processing video: {str(e)}")
            finally:
                torch.cuda.empty_cache()

        process_btn.click(
            on_process,
            inputs=[input_video],
            outputs=[
                output_playlist,
                status,
                video_description,
                highlight_types,
                analysis_accordion
            ],
            queue=True,
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the UI with the bundled examples file and the
    # SmolVLM2 checkpoint, then start the Gradio server.
    app = create_ui(
        examples_path="video_spec.json",
        model_path="HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    )
    app.launch()