|
import os |
|
import json |
|
import gradio as gr |
|
import tempfile |
|
import torch |
|
import spaces |
|
from pathlib import Path |
|
from transformers import AutoProcessor, AutoModelForVision2Seq |
|
import subprocess |
|
import logging |
|
import xml.etree.ElementTree as ET |
|
from xml.dom import minidom |
|
|
|
# Configure the root logger at INFO so this module's log records are emitted
# through basicConfig's default stderr handler (a no-op if the root logger
# already has handlers).
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
def load_examples(json_path: str) -> dict:
    """Load and return the parsed contents of a JSON examples file.

    Args:
        json_path: Path to a UTF-8 encoded JSON file.

    Returns:
        The decoded JSON document. The annotation says ``dict``, but
        ``json.load`` returns whatever the top-level JSON value is.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259). Relying on the platform default encoding
    # (e.g. cp1252 on Windows) would mis-decode any non-ASCII content.
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
def format_duration(seconds: float) -> str:
    """Render a number of seconds as a zero-padded ``HH:MM:SS`` string.

    Each field is truncated to an int. Hours do not wrap at 24, so 100000
    seconds renders as ``27:46:40``.
    """
    fields = (
        seconds // 3600,         # whole hours
        (seconds % 3600) // 60,  # minutes within the hour
        seconds % 60,            # seconds within the minute
    )
    return ":".join(f"{int(value):02d}" for value in fields)
|
|
|
def get_video_duration_seconds(video_path: str) -> float:
    """Return the container duration of *video_path* in seconds, using ``ffprobe``.

    Args:
        video_path: Path to a media file readable by ffprobe.

    Returns:
        The ``format.duration`` value reported by ffprobe, as a float.

    Raises:
        FileNotFoundError: If the ``ffprobe`` executable is not on PATH.
        RuntimeError: If ffprobe exits non-zero (missing/unreadable/non-media
            file) or its output contains no usable duration.
    """
    cmd = [
        "ffprobe",
        # "error" instead of "quiet": still no banner, but real failures are
        # written to stderr so they can be included in the exception below.
        "-v", "error",
        "-print_format", "json",
        "-show_format",
        video_path,
    ]
    # List form (no shell) so the path is never interpreted by a shell.
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(
            f"ffprobe failed on {video_path!r} (exit {result.returncode}): "
            f"{result.stderr.strip() or 'no error output'}"
        )

    try:
        info = json.loads(result.stdout)
        return float(info["format"]["duration"])
    except (KeyError, TypeError, ValueError) as err:
        # ValueError also covers json.JSONDecodeError and float("N/A"), which
        # ffprobe can report for inputs without a container duration.
        raise RuntimeError(
            f"ffprobe reported no usable duration for {video_path!r}"
        ) from err
|
|
|
class VideoHighlightDetector:
    """Vision-language model wrapper that describes videos and flags highlight clips.

    Every prompt goes through ``_generate``. That helper applies the processor's
    chat template, runs ``model.generate`` and decodes only the tokens produced
    after the prompt.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        batch_size: int = 8
    ):
        """Load the processor and model from *model_path* onto *device*.

        Args:
            model_path: Hub id or local directory of the checkpoint.
            device: Torch device the model and every prompt's tensors are moved to.
            batch_size: Stored on the instance; no method in this class reads it.
        """
        self.device = device
        self.batch_size = batch_size

        self.processor = AutoProcessor.from_pretrained(model_path)
        # Weights are loaded in bfloat16 (half the memory of float32).
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16
        ).to(device)

    def _generate(self, messages: list, **generate_kwargs) -> str:
        """Run one chat-formatted prompt through the model and return the reply text.

        Only the newly generated tokens are decoded. The result therefore does not
        depend on how the chat template labels the assistant turn, and there is no
        marker string to search for that could be missing or differently cased.

        Args:
            messages: Chat messages in the processor's ``apply_chat_template`` format.
            **generate_kwargs: Forwarded to ``model.generate``.

        Returns:
            The decoded reply, stripped of special tokens and surrounding whitespace.
        """
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.model.generate(**inputs, **generate_kwargs)
        # Assumes a decoder-only model (as SmolVLM2 is), whose generate() output
        # begins with the prompt tokens; drop them before decoding.
        prompt_length = inputs["input_ids"].shape[1]
        return self.processor.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

    def analyze_video_content(self, video_path: str) -> str:
        """Ask the model what kind of video *video_path* is and what happens in it.

        Generation is sampled (temperature 0.7), so repeated calls may differ.

        Returns:
            The model's description, lower-cased.
        """
        system_message = "You are a helpful assistant that can understand videos. Describe what type of video this is and what's happening in it."
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": "What type of video is this and what's happening in it? Be specific about the content type and general activities you observe."}
                ]
            }
        ]
        return self._generate(
            messages, max_new_tokens=512, do_sample=True, temperature=0.7
        ).lower()

    def analyze_segment(self, video_path: str) -> str:
        """Analyze a specific video segment and provide a brief description.

        Generation is sampled (temperature 0.7, at most 128 new tokens).
        """
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "Describe what is happening in this specific video segment in a brief, concise way."}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": "What is happening in this segment? Provide a brief description."}
                ]
            }
        ]
        return self._generate(messages, max_new_tokens=128, do_sample=True, temperature=0.7)

    def determine_highlights(self, video_description: str) -> str:
        """Ask the model (text only) which kinds of moments in the described video are highlights.

        Args:
            video_description: Free-text description, e.g. from ``analyze_video_content``.

        Returns:
            The model's free-text highlight criteria (sampled, temperature 0.7).
        """
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a professional video editor specializing in creating viral highlight reels."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": f"Based on this description, list which segments should be included in highlights: {video_description}"}]
            }
        ]
        return self._generate(messages, max_new_tokens=256, do_sample=True, temperature=0.7)

    @staticmethod
    def _is_affirmative(response: str) -> bool:
        """Return True if *response* contains the standalone word "yes" (case-insensitive)."""
        # Strip surrounding punctuation/markdown so "Yes," or "**yes**." still
        # count, while words like "eyes" or "yesterday" do not.
        words = (word.strip(".,!?;:\"'()*`") for word in response.lower().split())
        return "yes" in words

    def process_segment(self, video_path: str, highlight_types: str) -> bool:
        """Return True if the model says the clip shows any of *highlight_types*.

        Decoding is greedy (``do_sample=False``). The reply is matched for the
        whole word "yes", so "No, only their eyes are visible" is not a yes.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": f"Do you see any of these elements in the video: {highlight_types}? Answer yes or no."}
                ]
            }
        ]
        response = self._generate(messages, max_new_tokens=64, do_sample=False)
        return self._is_affirmative(response)
|
def create_xspf_playlist(video_path: str, segments: list, descriptions: list) -> str:
    """Create an XSPF playlist with one track per highlight segment of *video_path*.

    Each track's ``location`` is the video's base name as a percent-encoded
    relative URI, so a player resolves it against the folder the playlist is
    saved in. VLC's ``start-time``/``stop-time`` options restrict each track to
    its segment. The generic ``<meta rel="start">``/``<meta rel="end">`` elements
    (``HH:MM:SS``) are kept for other consumers.

    Args:
        video_path: Path of the analysed video; only its base name is written.
        segments: ``(start_seconds, end_seconds)`` pairs, one per highlight.
        descriptions: Annotation text per segment, in the same order. Entries
            beyond the shorter of the two lists are ignored.

    Returns:
        The pretty-printed playlist XML as a string.
    """
    from urllib.parse import quote

    xspf_ns = "http://xspf.org/ns/0/"
    vlc_ns = "http://www.videolan.org/vlc/playlist/ns/0/"
    vlc_app = "http://www.videolan.org/vlc/playlist/0"

    ET.register_namespace('vlc', vlc_ns)
    ET.register_namespace('', xspf_ns)

    root = ET.Element(f"{{{xspf_ns}}}playlist", {"version": "1"})

    video_filename = os.path.basename(video_path)
    playlist_title = ET.SubElement(root, f"{{{xspf_ns}}}title")
    playlist_title.text = f"{video_filename} - Highlights"

    # <location> must be a URI. "file:///<name>" pointed at the filesystem root
    # and left spaces/'#'/'%' unescaped; a quoted relative reference resolves
    # next to the playlist file instead.
    location_uri = quote(video_filename)

    tracklist = ET.SubElement(root, f"{{{xspf_ns}}}trackList")
    track_ids = []

    for idx, ((start_time, end_time), description) in enumerate(zip(segments, descriptions)):
        track = ET.SubElement(tracklist, f"{{{xspf_ns}}}track")

        ET.SubElement(track, f"{{{xspf_ns}}}location").text = location_uri
        ET.SubElement(track, f"{{{xspf_ns}}}title").text = f"Highlight {idx + 1}"
        ET.SubElement(track, f"{{{xspf_ns}}}annotation").text = description
        ET.SubElement(track, f"{{{xspf_ns}}}meta", {"rel": "start"}).text = format_duration(start_time)
        ET.SubElement(track, f"{{{xspf_ns}}}meta", {"rel": "end"}).text = format_duration(end_time)

        # VLC only honours per-track clipping via its own extension. <vlc:id> is
        # the id the playlist-level <vlc:item tid> refers to; start-time and
        # stop-time (in seconds) bound playback to this segment.
        track_ext = ET.SubElement(track, f"{{{xspf_ns}}}extension", {"application": vlc_app})
        ET.SubElement(track_ext, f"{{{vlc_ns}}}id").text = str(idx)
        ET.SubElement(track_ext, f"{{{vlc_ns}}}option").text = f"start-time={float(start_time)}"
        ET.SubElement(track_ext, f"{{{vlc_ns}}}option").text = f"stop-time={float(end_time)}"
        track_ids.append(str(idx))

    # Playlist-level VLC extension listing track ids in play order. It is built
    # from the tracks actually written, so every tid matches an emitted <vlc:id>.
    extension = ET.SubElement(root, f"{{{xspf_ns}}}extension", {"application": vlc_app})
    for tid in track_ids:
        ET.SubElement(extension, f"{{{vlc_ns}}}item", {"tid": tid})

    return minidom.parseString(ET.tostring(root, encoding='unicode')).toprettyxml(indent="  ")
|
|
|
def create_ui(examples_path: str, model_path: str):
    """Build the Gradio app that turns an uploaded video into an XSPF highlight playlist.

    Pipeline per click:
    1. The model describes the video.
    2. It proposes highlight criteria.
    3. The video is cut into fixed-length clips.
    4. Each clip the model says matches the criteria becomes a playlist track.

    Args:
        examples_path: JSON file passed to ``load_examples`` while building the UI.
        model_path: Checkpoint handed to ``VideoHighlightDetector`` on every run.

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` app.
    """
    # Loaded so a missing/invalid examples file fails at build time; nothing in
    # this UI consumes the parsed examples yet.
    load_examples(examples_path)

    # Longest accepted upload. This must agree with the "max 30 minutes" label
    # and the rejection message (it was 18000 s, i.e. 5 hours).
    max_duration_seconds = 30 * 60
    # Length of each clip the model inspects for highlights.
    segment_length = 10.0

    def _extract_clip(source: str, start: int, length: float, destination: str) -> None:
        """Re-encode ``length`` seconds of ``source`` starting at ``start`` into ``destination``."""
        cmd = [
            "ffmpeg",
            "-y",
            # Input seeking (-ss before -i) jumps to the clip start. As an output
            # option, ffmpeg decodes and discards everything before the clip on
            # every call, which makes the segment loop quadratic in video length.
            # Because the clip is re-encoded, input seeking stays frame-accurate.
            "-ss", str(start),
            "-i", source,
            "-t", str(length),
            "-c:v", "libx264",
            "-preset", "ultrafast",
            destination,
        ]
        subprocess.run(cmd, check=True)

    def _outputs(message, playlist_path=None, desc="", criteria="", show_details=False):
        """Build the five values for the click handler's outputs, in order.

        The playlist component is hidden whenever there is no playlist, so a
        download from a previous run does not stay visible after a failure.
        """
        return [
            gr.update(value=playlist_path, visible=playlist_path is not None),
            message,
            desc,
            criteria,
            gr.update(visible=show_details),
        ]

    with gr.Blocks() as app:
        gr.Markdown("# Video Highlight Playlist Generator")
        gr.Markdown("Upload a video and get an XSPF playlist of highlights!")

        with gr.Row():
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Upload your video (max 30 minutes)",
                    interactive=True
                )
                process_btn = gr.Button("Process Video", variant="primary")

            with gr.Column(scale=1):
                output_playlist = gr.File(
                    label="Highlight Playlist (XSPF)",
                    visible=False,
                    interactive=False,
                )
                status = gr.Markdown()

                analysis_accordion = gr.Accordion(
                    "Analysis Details",
                    open=True,
                    visible=False
                )

                with analysis_accordion:
                    video_description = gr.Markdown("")
                    highlight_types = gr.Markdown("")

        @spaces.GPU
        def on_process(video):
            """Run the full highlight pipeline for one uploaded video path."""
            if not video:
                return _outputs("Please upload a video")

            try:
                duration = get_video_duration_seconds(video)
                if duration > max_duration_seconds:
                    return _outputs("Video must be shorter than 30 minutes")

                detector = VideoHighlightDetector(model_path=model_path)

                video_desc = detector.analyze_video_content(video)
                formatted_desc = f"### Video Summary:\n{video_desc}"

                highlights = detector.determine_highlights(video_desc)
                formatted_highlights = f"### Highlight Criteria:\n{highlights}"

                kept_segments = []
                segment_descriptions = []
                # Clips start every segment_length seconds. The last clip's end
                # time is clamped to the video duration for the playlist.
                for start_time in range(0, int(duration), int(segment_length)):
                    end_time = min(start_time + segment_length, duration)
                    with tempfile.NamedTemporaryFile(suffix='.mp4') as temp_segment:
                        _extract_clip(video, start_time, segment_length, temp_segment.name)
                        if detector.process_segment(temp_segment.name, highlights):
                            logger.info("Keeping segment %s-%s s", start_time, end_time)
                            description = detector.analyze_segment(temp_segment.name)
                            kept_segments.append((start_time, end_time))
                            segment_descriptions.append(description)

                if not kept_segments:
                    return _outputs(
                        "No highlights detected in the video.",
                        desc=formatted_desc,
                        criteria=formatted_highlights,
                        show_details=True,
                    )

                playlist_content = create_xspf_playlist(video, kept_segments, segment_descriptions)

                # delete=False: the path is returned to the gr.File output, so the
                # file must outlive this with-block. Written as UTF-8 (the XML
                # default) rather than the platform locale encoding.
                with tempfile.NamedTemporaryFile(
                    mode='w', suffix='.xspf', delete=False, encoding='utf-8'
                ) as f:
                    f.write(playlist_content)
                    playlist_path = f.name

                return _outputs(
                    "Processing complete! Download the XSPF playlist.",
                    playlist_path=playlist_path,
                    desc=formatted_desc,
                    criteria=formatted_highlights,
                    show_details=True,
                )

            except Exception as e:
                # UI boundary: log the traceback and show the error to the user
                # instead of failing the event handler.
                logger.exception("Error processing video")
                return _outputs(f"Error processing video: {str(e)}")

            finally:
                torch.cuda.empty_cache()

        process_btn.click(
            on_process,
            inputs=[input_video],
            outputs=[
                output_playlist,
                status,
                video_description,
                highlight_types,
                analysis_accordion
            ],
            queue=True,
        )

    return app
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI from the example spec file and the
    # SmolVLM2 checkpoint, then start the Gradio server.
    examples_file = "video_spec.json"
    checkpoint = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
    app = create_ui(examples_file, checkpoint)
    app.launch()