Spaces:
Running
Running
from __future__ import annotations | |
import logging | |
import os | |
import re | |
import shutil | |
from pathlib import Path | |
from typing import Optional | |
import cv2 | |
import yt_dlp | |
from llama_index.core.agent.workflow import FunctionAgent | |
from llama_index.core.base.llms.types import TextBlock, ImageBlock, ChatMessage | |
from llama_index.core.tools import FunctionTool | |
from llama_index.llms.google_genai import GoogleGenAI | |
from tqdm import tqdm | |
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound | |
# --------------------------------------------------------------------------- | |
# Environment setup & logging | |
# --------------------------------------------------------------------------- | |
logger = logging.getLogger(__name__) | |
# --------------------------------------------------------------------------- | |
# Prompt loader | |
# --------------------------------------------------------------------------- | |
def load_prompt_from_file(filename: str = "../prompts/video_analyzer_prompt.txt") -> str:
    """Return the video-analysis system prompt stored in *filename*.

    *filename* is resolved relative to the directory containing this module.
    When the file is missing or cannot be read, the problem is logged and a
    short built-in prompt is returned instead, so callers always get a string.
    """
    # Deliberately tiny so the fallback costs almost no tokens.
    fallback_prompt = (
        "You are a video analyzer. Provide a factual, chronological "
        "description of the video, identify key events, and summarise insights."
    )
    location = Path(__file__).parent.joinpath(filename).resolve()

    try:
        text = location.read_text(encoding="utf-8")
    except FileNotFoundError:
        logger.error(
            "Prompt file %s not found. Using fallback prompt.", location
        )
        return fallback_prompt
    except Exception as err:  # pylint: disable=broad-except
        logger.error(
            "Error loading prompt file %s: %s", location, err, exc_info=True
        )
        return fallback_prompt

    logger.info("Successfully loaded system prompt from %s", location)
    return text
def extract_frames(video_path, output_dir, fps=1/2):
    """Sample frames from a video file and save them as JPEG images.

    Args:
        video_path: Path of the video file handed to ``cv2.VideoCapture``.
        output_dir: Directory that receives the ``frame_NNNNNN.jpg`` files;
            created if it does not exist.
        fps: Number of frames to keep per second of video (default: one frame
            every two seconds).

    Returns:
        ``(frames, duration)`` where *frames* is a list of
        ``(frame_path, timestamp_in_seconds)`` tuples and *duration* is the
        video length in seconds. ``([], None)`` is returned when the video
        cannot be opened or reports no usable frame rate.

    Raises:
        ValueError: If *fps* is not a positive number.
    """
    if fps <= 0:
        raise ValueError(f"fps must be a positive number, got {fps!r}")

    os.makedirs(output_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            logger.error("Error: Could not open video %s", video_path)
            return [], None

        video_fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Some containers/streams report 0 (or NaN) FPS; dividing by it would
        # crash, so treat the file as unreadable instead.
        if not video_fps > 0:
            logger.error("Video %s reports an invalid frame rate (%s)", video_path, video_fps)
            return [], None

        duration = frame_count / video_fps

        # Keep one frame out of every `interval`; never less than every frame.
        interval = max(1, int(video_fps / fps))

        frames = []
        frame_idx = 0
        # The reported frame count can be missing; let tqdm run without a total then.
        progress_total = frame_count if frame_count > 0 else None
        with tqdm(total=progress_total, desc="Extracting frames") as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_idx % interval == 0:
                    timestamp = frame_idx / video_fps
                    frame_path = os.path.join(output_dir, f"frame_{frame_idx:06d}.jpg")
                    # imwrite reports failure via its return value; only record
                    # frames that actually exist on disk.
                    if cv2.imwrite(frame_path, frame):
                        frames.append((frame_path, timestamp))
                    else:
                        logger.warning("Could not write frame to %s", frame_path)
                frame_idx += 1
                pbar.update(1)

        return frames, duration
    finally:
        # Always free the capture handle, even if reading or writing raised.
        cap.release()
def download_video_and_analyze(video_url: str) -> str:
    """Download *video_url*, sample frames from it, and return Gemini's analysis.

    The video is fetched with yt-dlp into ``downloaded_videos/``, frames are
    extracted with :func:`extract_frames` into ``frame_downloaded_videos/``,
    and the system prompt plus all frames are sent to the Gemini model named
    by ``VIDEO_ANALYZER_LLM_MODEL`` in a single chat message. The downloaded
    video and the frame directory are removed before returning, whether the
    analysis succeeded or not.

    Args:
        video_url: URL of the video to download.

    Returns:
        The model's textual response, or a short message when the video could
        not be downloaded or no frames could be extracted.
    """
    llm_model_name = os.getenv("VIDEO_ANALYZER_LLM_MODEL", "models/gemini-1.5-pro")
    gemini_api_key = os.getenv("GEMINI_API_KEY")

    download_dir = "downloaded_videos"
    frames_dir = "frame_downloaded_videos"
    ydl_opts = {
        "format": "best",
        "outtmpl": os.path.join(download_dir, "temp_video.%(ext)s"),
        "quiet": True,
        "extract_flat": True,
        "ignoreerrors": True,
        "cookiefile": "yt_cookie.txt",
    }

    video_path = None
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl_download:
            # extract_info returns the metadata of what was downloaded, so the
            # real file name (and extension) is known instead of assuming .mp4.
            # With ignoreerrors=True a failed download yields None.
            info = ydl_download.extract_info(video_url, download=True)
            if info:
                downloads = info.get("requested_downloads") or []
                if downloads and downloads[0].get("filepath"):
                    video_path = downloads[0]["filepath"]
                else:
                    video_path = ydl_download.prepare_filename(info)

        if not video_path or not os.path.exists(video_path):
            logger.error("Could not download video from %s", video_url)
            return f"Error: could not download video from {video_url}"

        logger.info("Processing video: %s", video_url)

        frames, _duration = extract_frames(video_path, frames_dir)
        if not frames:
            logger.info("No frames extracted from %s", video_url)
            return f"No frames extracted from {video_url}"

        # One user message: the prompt text followed by every sampled frame.
        blocks = [TextBlock(text=load_prompt_from_file())]
        blocks.extend(ImageBlock(path=frame_path) for frame_path, _timestamp in frames)

        llm = GoogleGenAI(api_key=gemini_api_key, model=llm_model_name)
        logger.info("Using LLM model: %s", llm_model_name)
        response = llm.chat([ChatMessage(role="user", blocks=blocks)])
        return response.message.content or ""
    finally:
        # Clean up temporary files on every exit path, including errors.
        shutil.rmtree(frames_dir, ignore_errors=True)
        if video_path and os.path.exists(video_path):
            os.remove(video_path)
# --- Helper function to extract YouTube Video ID --- | |
def extract_video_id(url: str) -> Optional[str]:
    """Extract the YouTube video ID from a URL or a bare video ID.

    Recognised inputs (scheme and sub-domain such as ``www.``/``m.`` optional):

    * ``youtube.com/watch?v=ID`` with ``v`` anywhere in the query string
    * ``youtu.be/ID``
    * ``youtube.com/embed/ID``, ``/shorts/ID``, ``/live/ID``, ``/v/ID``
      (also on ``youtube-nocookie.com``)
    * a bare 11-character video ID

    Query parameters and ``#fragment`` suffixes are never part of the
    returned ID.

    Returns:
        The video ID, or ``None`` when *url* matches none of the formats.
    """
    candidate = url.strip() if url else ""

    # A bare ID: exactly 11 URL-safe base64 characters.
    if re.fullmatch(r"[A-Za-z0-9_-]{11}", candidate):
        return candidate

    patterns = (
        # Standard watch URL; other query parameters may precede v=.
        r"^(?:https?://)?(?:[\w-]+\.)?youtube\.com/watch\?(?:[^#]*&)?v=([^&#]+)",
        # Short share links.
        r"^(?:https?://)?(?:www\.)?youtu\.be/([^?&#/]+)",
        # Path-style URLs: embeds, shorts, live streams, legacy /v/.
        r"^(?:https?://)?(?:[\w-]+\.)?youtube(?:-nocookie)?\.com/"
        r"(?:embed|shorts|live|v)/([^?&#/]+)",
    )
    for pattern in patterns:
        match = re.search(pattern, candidate)
        if match:
            return match.group(1)

    print("Aucun ID trouvé")
    return None
# --- YouTube Transcript Tool --- | |
def get_youtube_transcript(video_url_or_id: str, languages: str | None = None) -> str:
    """Fetch the transcript text of a YouTube video.

    Args:
        video_url_or_id: A YouTube URL, or a bare 11-character video ID.
        languages: Preferred language codes in priority order, given as a
            comma-separated string (e.g. ``"en"`` or ``"en,es"``). A list or
            tuple of codes is also tolerated at runtime. Defaults to English.

    Returns:
        The transcript as a single space-joined string. If no transcript
        exists in the requested languages, the first transcript listed for
        the video is used instead. On failure a message starting with
        ``"Error"`` is returned rather than raising.
    """
    # Normalise to a list of codes: find_transcript iterates its argument, so
    # a plain string such as "en" would otherwise be read as "e", "n".
    if languages is None:
        language_codes = ["en"]
    elif isinstance(languages, str):
        language_codes = [code.strip() for code in languages.split(",") if code.strip()] or ["en"]
    else:
        language_codes = [str(code) for code in languages] or ["en"]

    logger.info("Attempting to fetch YouTube transcript for: %s", video_url_or_id)

    video_id = extract_video_id(video_url_or_id)
    if not video_id and video_url_or_id:
        # Accept a bare video ID even if the URL parser does not recognise it.
        stripped = video_url_or_id.strip()
        if re.fullmatch(r"[A-Za-z0-9_-]{11}", stripped):
            video_id = stripped
    if not video_id:
        logger.error("Could not extract video ID from: %s", video_url_or_id)
        return f"Error: Invalid YouTube URL or Video ID format: {video_url_or_id}"

    transcript_list = None
    try:
        # Fetch available transcripts
        api = YouTubeTranscriptApi()
        transcript_list = api.list(video_id)
        # Try to find a transcript in the specified languages
        transcript = transcript_list.find_transcript(language_codes)
        # Fetch the actual transcript data and join the snippet texts
        transcript_data = transcript.fetch()
        full_transcript = " ".join(snippet.text for snippet in transcript_data)
        logger.info(
            "Successfully fetched transcript for video ID %s in language %s.",
            video_id, transcript.language,
        )
        return full_transcript
    except TranscriptsDisabled:
        logger.warning("Transcripts are disabled for video ID: %s", video_id)
        return f"Error: Transcripts are disabled for this video (ID: {video_id})."
    except NoTranscriptFound as e:
        logger.warning(
            "No transcript found for video ID %s in languages %s: %s",
            video_id, language_codes, e,
        )
        # Try fetching any available transcript if specific languages failed
        try:
            logger.info("Attempting to fetch any available transcript for %s", video_id)
            # Iterating the transcript list yields every transcript available.
            candidates = list(transcript_list) if transcript_list is not None else []
            if not candidates:
                raise LookupError("no transcripts are listed for this video")
            any_transcript = candidates[0]
            any_transcript_data = any_transcript.fetch()
            full_transcript = " ".join(snippet.text for snippet in any_transcript_data)
            logger.info(
                "Successfully fetched fallback transcript for video ID %s in language %s.",
                video_id, any_transcript.language,
            )
            return full_transcript
        except Exception as fallback_e:  # pylint: disable=broad-except
            logger.error(
                "Could not find any transcript for video ID %s. Original error: %s. Fallback error: %s",
                video_id, e, fallback_e,
            )
            return (
                f"Error: No transcript found for video ID {video_id} in languages "
                f"{language_codes} or any fallback language."
            )
    except Exception as e:  # pylint: disable=broad-except
        logger.error(
            "Unexpected error fetching transcript for video ID %s: %s", video_id, e, exc_info=True
        )
        return f"Error fetching transcript: {e}"
# Tool wrapper exposing `download_video_and_analyze` to the agent.
download_video_and_analyze_tool = FunctionTool.from_defaults(
    fn=download_video_and_analyze,
    name="download_video_and_analyze",
    description=(
        "Downloads a video (YouTube or direct URL), samples representative frames, "
        "and feeds them to Gemini for multimodal analysis—returning a rich textual summary "
        "of the visual content."
    ),
)
# Tool wrapper exposing `get_youtube_transcript` to the agent.
youtube_transcript_tool = FunctionTool.from_defaults(
    name="get_youtube_transcript",
    description=(
        "(YouTube) Fetches the transcript text for a given YouTube video URL or video ID. "
        "Specify preferred languages (e.g., 'en', 'es'). Returns transcript or error."
    ),
    fn=get_youtube_transcript,
)
# --------------------------------------------------------------------------- | |
# Agent factory | |
# --------------------------------------------------------------------------- | |
def initialize_video_analyzer_agent() -> FunctionAgent:
    """Build the *video_analyzer_agent* `FunctionAgent` and return it.

    The Gemini model name is read from ``VIDEO_ANALYZER_LLM_MODEL`` and the
    API key from ``GEMINI_API_KEY``. The agent is given the system prompt from
    `load_prompt_from_file` and the two tools defined in this module.

    Raises:
        ValueError: If ``GEMINI_API_KEY`` is not set.
        Exception: Any error raised while constructing the LLM or the agent is
            logged and re-raised unchanged.
    """
    logger.info("Initialising VideoAnalyzerAgent …")

    api_key = os.getenv("GEMINI_API_KEY")
    model_name = os.getenv("VIDEO_ANALYZER_LLM_MODEL", "models/gemini-1.5-pro")

    if not api_key:
        logger.error("GEMINI_API_KEY not found in environment variables.")
        raise ValueError("GEMINI_API_KEY must be set")

    agent_description = (
        "VideoAnalyzerAgent inspects video files using Gemini's multimodal "
        "video understanding capabilities, producing factual scene analysis, "
        "temporal segmentation, and concise summaries as guided by the system "
        "prompt."
    )
    handoff_targets = [
        "planner_agent",
        "research_agent",
        "reasoning_agent",
        "code_agent",
    ]

    try:
        gemini = GoogleGenAI(api_key=api_key, model=model_name)
        logger.info("Using LLM model: %s", model_name)

        video_agent = FunctionAgent(
            name="video_analyzer_agent",
            description=agent_description,
            llm=gemini,
            system_prompt=load_prompt_from_file(),
            tools=[download_video_and_analyze_tool, youtube_transcript_tool],
            can_handoff_to=handoff_targets,
        )
    except Exception as err:  # pylint: disable=broad-except
        logger.error("Error during VideoAnalyzerAgent initialisation: %s", err, exc_info=True)
        raise

    logger.info("VideoAnalyzerAgent initialised successfully.")
    return video_agent
if __name__ == "__main__":
    # Manual smoke test: build the agent, analyse one video, fetch one transcript.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logger.info("Running video_analyzer_agent.py directly for testing …")

    if not os.getenv("GEMINI_API_KEY"):
        print("Error: GEMINI_API_KEY environment variable not set. Cannot run test.")
    else:
        try:
            initialize_video_analyzer_agent()
            print("Video Analyzer Agent initialised successfully for testing.")
            summary = download_video_and_analyze("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
            print("\n--- Gemini summary ---\n")
            print(summary)
        except Exception as exc:  # pylint: disable=broad-except
            print(f"Error during testing: {exc}")

    # The transcript tool does not use the Gemini key, so it is exercised
    # regardless. youtube-transcript-api is imported unconditionally at the top
    # of this module, so no availability flag is needed here.
    try:
        print("\nTesting YouTube transcript tool...")
        # Example video: "Attention is All You Need" paper explanation
        yt_url = "https://www.youtube.com/watch?v=TQQlZhbC5ps"
        transcript = get_youtube_transcript(yt_url)
        if not transcript.startswith("Error"):
            print(f"Transcript fetched (first 500 chars):\n{transcript[:500]}...")
        else:
            print(f"YouTube Transcript Fetch Failed: {transcript}")
    except Exception as e:  # pylint: disable=broad-except
        print(f"Error during testing: {e}")