GAIA_Agent / agents /video_analyzer_agent.py
Delanoe Pirard
clean 2
e1b7852
raw
history blame
12.9 kB
from __future__ import annotations
import logging
import os
import re
import shutil
from pathlib import Path
from typing import Optional
import cv2
import yt_dlp
from llama_index.core.agent.workflow import FunctionAgent
from llama_index.core.base.llms.types import TextBlock, ImageBlock, ChatMessage
from llama_index.core.tools import FunctionTool
from llama_index.llms.google_genai import GoogleGenAI
from tqdm import tqdm
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
# ---------------------------------------------------------------------------
# Environment setup & logging
# ---------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Prompt loader
# ---------------------------------------------------------------------------
def load_prompt_from_file(filename: str = "../prompts/video_analyzer_prompt.txt") -> str:
    """Return the video-analysis system prompt stored in *filename*.

    *filename* is resolved relative to the directory that contains this
    module. When the file is missing or cannot be read, the problem is logged
    and a short built-in prompt is returned instead.
    """
    # Kept deliberately short so the fallback costs few tokens.
    fallback_prompt = (
        "You are a video analyzer. Provide a factual, chronological "
        "description of the video, identify key events, and summarise insights."
    )
    resolved = Path(__file__).parent.joinpath(filename).resolve()

    try:
        text = resolved.read_text(encoding="utf-8")
    except FileNotFoundError:
        logger.error(
            "Prompt file %s not found. Using fallback prompt.", resolved
        )
        return fallback_prompt
    except Exception as exc:  # pylint: disable=broad-except
        logger.error(
            "Error loading prompt file %s: %s", resolved, exc, exc_info=True
        )
        return fallback_prompt

    logger.info("Successfully loaded system prompt from %s", resolved)
    return text
def extract_frames(video_path, output_dir, fps=1/2):
    """Sample frames from a video file and save them as JPEG images.

    Args:
        video_path: Path of the video file, handed to ``cv2.VideoCapture``.
        output_dir: Directory receiving the ``frame_<index>.jpg`` files; it is
            created when missing.
        fps: Target sampling rate in frames per second (default: one frame
            every two seconds). Must be positive.

    Returns:
        A ``(frames, duration)`` tuple. ``frames`` is a list of
        ``(frame_path, timestamp)`` tuples where the timestamp (seconds) is
        the frame index divided by the reported frame rate; ``duration`` is
        the reported frame count divided by that frame rate. ``([], None)``
        is returned when the video cannot be opened or reports no usable
        frame rate.

    Raises:
        ValueError: If *fps* is not positive.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    os.makedirs(output_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            logger.error("Could not open video %s", video_path)
            return [], None

        video_fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Written as "not > 0" so that 0 and NaN are both rejected; either
        # would otherwise break the divisions below.
        if not video_fps > 0:
            logger.error(
                "Video %s reports an unusable frame rate (%r)", video_path, video_fps
            )
            return [], None

        duration = frame_count / video_fps
        # Keep every ``interval``-th frame; never less than every frame.
        interval = max(1, int(video_fps / fps))

        frame_idx = 0
        # The reported frame count can be non-positive; let tqdm run without
        # a total in that case.
        progress_total = frame_count if frame_count > 0 else None
        with tqdm(total=progress_total, desc="Extracting frames") as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_idx % interval == 0:
                    frame_path = os.path.join(output_dir, f"frame_{frame_idx:06d}.jpg")
                    # imwrite reports failure through its return value; only
                    # list frames that were actually written.
                    if cv2.imwrite(frame_path, frame):
                        frames.append((frame_path, frame_idx / video_fps))
                    else:
                        logger.warning("Could not write frame to %s", frame_path)
                frame_idx += 1
                pbar.update(1)
    finally:
        # Release the capture on every exit path, including exceptions.
        cap.release()

    return frames, duration
def download_video_and_analyze(video_url: str) -> str:
    """Download a video, sample its frames and return the model's analysis.

    The video at *video_url* is downloaded with yt-dlp into
    ``downloaded_videos/`` (``yt_cookie.txt`` is passed as the cookie file),
    frames are sampled into ``frame_downloaded_videos/`` with
    ``extract_frames``, and the frames are sent together with the system
    prompt to the model named by ``VIDEO_ANALYZER_LLM_MODEL`` (default
    ``models/gemini-1.5-pro``) using the ``GEMINI_API_KEY`` value. The
    temporary video and frame files are removed on every exit path.

    Args:
        video_url: URL of the video to analyse.

    Returns:
        The content of the model's response message, or a message starting
        with ``"No frames extracted from"`` when the download produced no
        file or no frame could be read from it.
    """
    llm_model_name = os.getenv("VIDEO_ANALYZER_LLM_MODEL", "models/gemini-1.5-pro")
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    cookie_txt = "yt_cookie.txt"
    download_dir = Path("downloaded_videos")
    temp_dir = "frame_downloaded_videos"
    ydl_opts = {
        'format': 'best',
        'outtmpl': os.path.join("downloaded_videos", 'temp_video.%(ext)s'),
        'quiet': True,
        'extract_flat': True,
        'ignoreerrors': True,
        "cookiefile": cookie_txt
    }

    def _temp_videos():
        # The extension is chosen by yt-dlp (``%(ext)s``), so match on the stem.
        if not download_dir.is_dir():
            return []
        return sorted(download_dir.glob("temp_video.*"))

    def _remove_temp_videos():
        for leftover in _temp_videos():
            leftover.unlink(missing_ok=True)

    try:
        # A file left behind by an interrupted run would make yt-dlp skip the
        # download and the stale video would be analysed instead.
        _remove_temp_videos()

        with yt_dlp.YoutubeDL(ydl_opts) as ydl_download:
            ydl_download.download([video_url])
        logger.info("Processing video: %s", video_url)

        # Ignore yt-dlp's partial-download/state files when locating the video.
        downloaded = [p for p in _temp_videos() if p.suffix not in (".part", ".ytdl")]
        if not downloaded:
            logger.warning("Download produced no video file for %s", video_url)
            return f"No frames extracted from {video_url}"

        os.makedirs(temp_dir, exist_ok=True)
        frames, _duration = extract_frames(str(downloaded[0]), temp_dir)
        if not frames:
            logger.info("No frames extracted from %s", video_url)
            return f"No frames extracted from {video_url}"

        # One text block carrying the prompt, followed by one image block per frame.
        blocks = [TextBlock(text=load_prompt_from_file())]
        for frame_path, _timestamp in tqdm(frames, desc="Collecting frames"):
            blocks.append(ImageBlock(path=frame_path))

        llm = GoogleGenAI(api_key=gemini_api_key, model=llm_model_name)
        logger.info("Using LLM model: %s", llm_model_name)
        response = llm.chat([ChatMessage(role="user", blocks=blocks)])
        return response.message.content
    finally:
        # Clean up temporary files even on early return or failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        _remove_temp_videos()
# --- Helper function to extract YouTube Video ID ---
# A bare video ID: exactly 11 URL-safe base64 characters.
_BARE_VIDEO_ID_RE = re.compile(r'^[A-Za-z0-9_-]{11}$')

# Watch URLs: https://www.youtube.com/watch?v=VIDEO_ID
_WATCH_URL_RE = re.compile(
    r'^(?:https?://)?'          # optional scheme
    r'(?:(?:www|m|music)\.)?'   # optional sub-domain
    r'youtube\.com/watch\?'     # fixed domain and path
    r'(?:[^#]*&)?'              # other query parameters before v=, if any
    r'v=([^&#]+)'               # the ID: everything up to the next & or #
)

# URLs carrying the ID as a path segment: youtu.be/ID, /embed/ID, /shorts/ID, ...
_PATH_URL_RE = re.compile(
    r'^(?:https?://)?'
    r'(?:(?:www|m)\.)?'
    r'(?:youtu\.be/|youtube(?:-nocookie)?\.com/(?:embed|shorts|live|v)/)'
    r'([^/?&#]+)'
)


def extract_video_id(url: str) -> Optional[str]:
    """Extract the YouTube video ID from a URL or a bare video ID.

    Recognised inputs (surrounding whitespace is ignored):

    * a bare 11-character video ID,
    * ``youtube.com/watch?...v=ID`` URLs (``www.``, ``m.`` or ``music.``
      sub-domain optional, ``v`` anywhere in the query string),
    * ``youtu.be/ID`` and ``youtube.com/{embed,shorts,live,v}/ID`` URLs.

    Args:
        url: The URL or video ID to parse.

    Returns:
        The video ID, or ``None`` when *url* matches none of the formats
        above. Nothing is printed or logged here; reporting the failure is
        left to the caller.
    """
    if not url:
        return None

    candidate = url.strip()
    if _BARE_VIDEO_ID_RE.match(candidate):
        return candidate

    for pattern in (_WATCH_URL_RE, _PATH_URL_RE):
        match = pattern.match(candidate)
        if match:
            return match.group(1)
    return None
# --- YouTube Transcript Tool ---
def get_youtube_transcript(video_url_or_id: str, languages: str | None = None) -> str:
    """Fetch the transcript text of a YouTube video.

    Args:
        video_url_or_id: Reference to the video; it is handed to
            ``extract_video_id`` to obtain the video ID.
        languages: Preferred language codes in priority order. Either a list
            (``["en", "es"]``) or a single string holding one or more codes
            (``"en"``, ``"en, es"``). Defaults to ``["en"]``.

    Returns:
        The transcript snippets joined by single spaces. When no transcript
        exists in the preferred languages, the first transcript listed for
        the video is used instead. On failure a message starting with
        ``"Error"`` is returned; no exception is raised for fetch errors.
    """
    # The tool may be called with a plain string; passing it through as-is
    # would make the API iterate over its characters.
    if languages is None:
        preferred = ["en"]
    elif isinstance(languages, str):
        preferred = re.findall(r"[A-Za-z][A-Za-z0-9-]*", languages) or ["en"]
    else:
        preferred = list(languages)

    logger.info("Attempting to fetch YouTube transcript for: %s", video_url_or_id)
    video_id = extract_video_id(video_url_or_id)
    if not video_id:
        logger.error("Could not extract video ID from: %s", video_url_or_id)
        return f"Error: Invalid YouTube URL or Video ID format: {video_url_or_id}"

    # Bound before the try so the NoTranscriptFound handler can always read it.
    transcript_list = None
    try:
        # Fetch available transcripts
        transcript_list = YouTubeTranscriptApi().list(video_id)
        # Try to find a transcript in the specified languages
        transcript = transcript_list.find_transcript(preferred)
        # Fetch the snippets and combine their text into a single string
        full_transcript = " ".join(snippet.text for snippet in transcript.fetch())
        logger.info(
            "Successfully fetched transcript for video ID %s in language %s.",
            video_id, transcript.language,
        )
        return full_transcript
    except TranscriptsDisabled:
        logger.warning("Transcripts are disabled for video ID: %s", video_id)
        return f"Error: Transcripts are disabled for this video (ID: {video_id})."
    except NoTranscriptFound as e:
        logger.warning(
            "No transcript found for video ID %s in languages %s.", video_id, preferred
        )
        fallback_error = (
            f"Error: No transcript found for video ID {video_id} "
            f"in languages {preferred} or any fallback language."
        )
        # Try fetching any available transcript if specific languages failed
        try:
            logger.info("Attempting to fetch any available transcript for %s", video_id)
            # Take the first transcript the list yields, whatever its language.
            any_transcript = (
                next(iter(transcript_list), None) if transcript_list is not None else None
            )
            if any_transcript is None:
                logger.error("No transcripts are listed for video ID %s.", video_id)
                return fallback_error
            full_transcript = " ".join(
                snippet.text for snippet in any_transcript.fetch()
            )
            logger.info(
                "Successfully fetched fallback transcript for video ID %s in language %s.",
                video_id, any_transcript.language,
            )
            return full_transcript
        except Exception as fallback_e:  # pylint: disable=broad-except
            logger.error(
                "Could not find any transcript for video ID %s. "
                "Original error: %s. Fallback error: %s",
                video_id, e, fallback_e,
            )
            return fallback_error
    except Exception as e:  # pylint: disable=broad-except
        logger.error(
            "Unexpected error fetching transcript for video ID %s: %s",
            video_id, e, exc_info=True,
        )
        return f"Error fetching transcript: {e}"
# Tool wrapper that exposes download_video_and_analyze to the agent.
download_video_and_analyze_tool = FunctionTool.from_defaults(
    fn=download_video_and_analyze,
    name="download_video_and_analyze",
    description=(
        "Downloads a video (YouTube or direct URL), samples representative frames, "
        "and feeds them to Gemini for multimodal analysis—returning a rich textual summary "
        "of the visual content."
    ),
)
# Tool wrapper that exposes get_youtube_transcript to the agent.
youtube_transcript_tool = FunctionTool.from_defaults(
    name="get_youtube_transcript",
    description=(
        "(YouTube) Fetches the transcript text for a given YouTube video URL or video ID. "
        "Specify preferred languages (e.g., 'en', 'es'). Returns transcript or error."
    ),
    fn=get_youtube_transcript,
)
# ---------------------------------------------------------------------------
# Agent factory
# ---------------------------------------------------------------------------
def initialize_video_analyzer_agent() -> FunctionAgent:
    """Build and return the ``video_analyzer_agent`` ``FunctionAgent``.

    The model name is read from ``VIDEO_ANALYZER_LLM_MODEL`` (default
    ``models/gemini-1.5-pro``) and the API key from ``GEMINI_API_KEY``. The
    agent is given the system prompt from ``load_prompt_from_file`` and the
    two tools defined in this module.

    Raises:
        ValueError: If ``GEMINI_API_KEY`` is not set.
        Exception: Any error raised while constructing the LLM or the agent
            is logged and re-raised unchanged.
    """
    logger.info("Initialising VideoAnalyzerAgent …")

    model_name = os.getenv("VIDEO_ANALYZER_LLM_MODEL", "models/gemini-1.5-pro")
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        logger.error("GEMINI_API_KEY not found in environment variables.")
        raise ValueError("GEMINI_API_KEY must be set")

    agent_description = (
        "VideoAnalyzerAgent inspects video files using Gemini's multimodal "
        "video understanding capabilities, producing factual scene analysis, "
        "temporal segmentation, and concise summaries as guided by the system "
        "prompt."
    )
    # Names of the agents this one is allowed to hand off to.
    handoff_targets = [
        "planner_agent",
        "research_agent",
        "reasoning_agent",
        "code_agent",
    ]

    try:
        gemini = GoogleGenAI(api_key=api_key, model=model_name)
        logger.info("Using LLM model: %s", model_name)
        video_agent = FunctionAgent(
            name="video_analyzer_agent",
            description=agent_description,
            llm=gemini,
            system_prompt=load_prompt_from_file(),
            tools=[download_video_and_analyze_tool, youtube_transcript_tool],
            can_handoff_to=handoff_targets,
        )
    except Exception as exc:  # pylint: disable=broad-except
        logger.error("Error during VideoAnalyzerAgent initialisation: %s", exc, exc_info=True)
        raise

    logger.info("VideoAnalyzerAgent initialised successfully.")
    return video_agent
if __name__ == "__main__":
    # Manual smoke test of the agent factory and both tools. It talks to
    # external services and is skipped when GEMINI_API_KEY is not set.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logger.info("Running video_analyzer_agent.py directly for testing …")
    if not os.getenv("GEMINI_API_KEY"):
        print("Error: GEMINI_API_KEY environment variable not set. Cannot run test.")
    else:
        try:
            test_agent = initialize_video_analyzer_agent()
            summary = download_video_and_analyze("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
            print("\n--- Gemini summary ---\n")
            print(summary)
            print("Video Analyzer Agent initialised successfully for testing.")
        except Exception as exc:
            print(f"Error during testing: {exc}")
            test_agent = None
        try:
            # Test YouTube transcript tool directly. youtube_transcript_api is
            # imported unconditionally at the top of the module, so there is
            # no availability flag to check here.
            print("\nTesting YouTube transcript tool...")
            # Example video: "Attention is All You Need" paper explanation
            yt_url = "https://www.youtube.com/watch?v=TQQlZhbC5ps"
            transcript = get_youtube_transcript(yt_url)
            # Failure messages start with "Error", not always with "Error:".
            if transcript.startswith("Error"):
                print(f"YouTube Transcript Fetch Failed: {transcript}")
            else:
                print(f"Transcript fetched (first 500 chars):\n{transcript[:500]}...")
        except Exception as e:
            print(f"Error during testing: {e}")