import google.generativeai as genai
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper, WikipediaAPIWrapper
from langchain.agents import Tool, AgentExecutor, ConversationalAgent
from langchain.memory import ConversationBufferMemory
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from PIL import Image
import base64
import os
import tempfile
import time
import re
import json
from typing import List, Optional, Dict, Any
from urllib.parse import urlparse
import requests
import yt_dlp
from bs4 import BeautifulSoup
from difflib import SequenceMatcher
class Agent:
    """A conversational agent that wraps a Gemini chat model with web search and content-analysis tools."""

    def __init__(self, model_name: str = "gemini", api_key: str = "BasicAgent"):
        self.model = model_name
        self.api_key = api_key
        # If model_name starts with "gemini", the Gemini chat model is used.
        self.tools = [
            Tool(
                name='web_search',
                func=self._web_search,
                description="A tool to search the web for information."
            ),
            Tool(
                name='analyze_video',
                func=self._analyze_video,
                description="A tool to analyze video content."
            ),
            Tool(
                name='analyze_image',
                func=self._analyze_image,
                description="A tool to analyze image content."
            ),
            Tool(
                name='analyze_list',
                func=self._analyze_list,
                description="A tool to analyze a list."
            ),
            Tool(
                name='analyze_table',
                func=self._analyze_table,
                description="A tool to analyze a table."
            ),
            Tool(
                name='analyze_text',
                func=self._analyze_text,
                description="A tool to analyze text content."
            ),
            Tool(
                name='analyze_url',
                func=self._analyze_url,
                description="A tool to analyze a URL."
            ),
            Tool(
                name='wikipedia_search',
                func=WikipediaAPIWrapper().run,
                description="A tool to search Wikipedia."
            ),
        ]
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="output",
            input_key="input"
        )
        self.llm = self._initialize_model(model_name, api_key)
        self.agent = self.initialize_agent()
    def _initialize_model(self, model_name: str, api_key: str):
        if model_name.startswith("gemini"):
            return self._initialize_gemini(model_name)
        else:
            raise ValueError(f"Unsupported model name: {model_name}. Please use a valid model name.")
    def _initialize_gemini(self, model_name: str = "gemini-2.0-flash"):
        # Block harmful content at the "medium and above" threshold for every category.
        safety_settings = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        }
        # ChatGoogleGenerativeAI does not accept a system_message constructor argument,
        # so the system instructions are stored here and prepended to direct LLM calls.
        self.system_message = SystemMessage(content=(
            "You are a precise AI assistant that helps users find information and analyze content. "
            "You can directly understand and analyze YouTube videos, images, and other content. "
            "When analyzing videos, focus on relevant details like dialogue, text, and key visual elements. "
            "For lists, tables, and structured data, ensure proper formatting and organization. "
            "If you need additional context, clearly explain what is needed."
        ))
        return ChatGoogleGenerativeAI(
            model=model_name,
            google_api_key=self.api_key,
            temperature=0,              # deterministic, low-variance answers
            max_output_tokens=2000,
            safety_settings=safety_settings,
        )
    def initialize_agent(self):
        PREAMBLE = (
            "You are a helpful assistant. You can use the tools provided to search the web and to "
            "analyze videos, images, lists, and tables. Please provide clear and concise answers.\n\n"
            "TOOLS: You have access to the following tools:"
        )
        FORMAT_PROMPT = (
            "To use a tool, follow this format:\n\n"
            "Thought: Do I need to use a tool? Yes\n"
            "Action: the action to take, should be one of [{tool_names}]\n"
            "Action Input: the input to the action\n"
            "Observation: the result of the action\n\n"
            "When you have the final answer, or if you do not need to use a tool, you MUST use the format:\n\n"
            "Thought: Do I need to use a tool? No\n"
            "{ai_prefix}: [your final response]"
        )
        POSTFIX = (
            "Previous conversation:\n"
            "{chat_history}\n\n"
            "New question: {input}\n"
            "{agent_scratchpad}"
        )
        agent = ConversationalAgent.from_llm_and_tools(
            llm=self.llm,
            tools=self.tools,
            prefix=PREAMBLE,
            suffix=POSTFIX,
            format_instructions=FORMAT_PROMPT,
            input_variables=["input", "chat_history", "agent_scratchpad"],
        )
        return AgentExecutor.from_agent_and_tools(
            agent=agent,
            tools=self.tools,
            memory=self.memory,
            verbose=True,
            handle_parsing_errors=True,
            max_iterations=3,
        )
    def run(self, query: str) -> str:
        """
        Run the agent on the given query, retrying with a simple linear backoff on failure.
        """
        max_retries = 3
        retry_delay = 2
        for attempt in range(max_retries):
            try:
                return self.agent.run(input=query)
            except Exception as e:
                sleep_time = retry_delay * (attempt + 1)
                print(f"Attempt {attempt + 1} failed: {e}. Retrying in {sleep_time} seconds...")
                time.sleep(sleep_time)
        return f"Error: request failed after {max_retries} attempts. Please try again later."
    def _web_search(self, query: str, site: Optional[str] = None) -> str:
        """
        Perform a web search using DuckDuckGo and return the top results.
        """
        search = DuckDuckGoSearchAPIWrapper(max_results=5)
        results = search.run(f"{query} site:{site}" if site else query)
        if results:
            return results
        else:
            return "No results found."
    def _analyze_video(self, video_url: str) -> str:
        """
        Analyze a YouTube video from its metadata (title and description) and return an LLM summary.
        """
        ydl_opts = {
            'quiet': True,
            'skip_download': True,
            'no_warnings': True,
            'extract_flat': True,
            'noplaylist': True,
            'youtube_include_dash_manifest': False
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            try:
                info = ydl.extract_info(video_url, download=False, process=False)
                if 'entries' in info:
                    info = info['entries'][0]
                title = info.get('title', 'No title available.')
                description = info.get('description', 'No description available.')
                prompt = (
                    f"Please analyze this YouTube video:\n"
                    f"Title: {title}\n"
                    f"URL: {video_url}\n"
                    f"Description: {description}\n\n"
                    "Please provide a detailed analysis focusing on:\n"
                    "1. Main topic and key points from the title and description\n"
                    "2. Expected visual elements and scenes\n"
                    "3. Overall message or purpose\n"
                    "4. Target audience"
                )
                messages = [self.system_message, HumanMessage(content=prompt)]
                response = self.llm.invoke(messages)
                return response.content if hasattr(response, 'content') else str(response)
            except Exception as e:
                if 'Sign in to confirm' in str(e):
                    return "This video requires sign-in. Please provide a different video URL."
                return f"Error accessing video: {str(e)}"
    def _analyze_image(self, image_url: str) -> str:
        """
        Download an image and ask the multimodal model for a detailed description.
        """
        try:
            response = requests.get(image_url, timeout=30)
            if response.status_code == 200:
                # Send the actual image bytes to the model as a base64 data URL,
                # so the analysis is based on the image content rather than the URL string alone.
                mime_type = response.headers.get("Content-Type", "image/jpeg").split(";")[0]
                image_data = base64.b64encode(response.content).decode("utf-8")
                prompt = (
                    "Please analyze this image. Provide a detailed description of the content "
                    "with focus on the following aspects:\n"
                    "1. Main subjects and objects in the image\n"
                    "2. Colors, textures, and patterns\n"
                    "3. Overall mood or atmosphere\n"
                    "4. Any text or symbols present in the image\n"
                    "5. Possible context or background information"
                )
                messages = [
                    self.system_message,
                    HumanMessage(content=[
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_data}"}},
                    ]),
                ]
                response = self.llm.invoke(messages)
                return response.content if hasattr(response, 'content') else str(response)
            else:
                return f"Error accessing image: {response.status_code}"
        except Exception as e:
            return f"Error processing image: {str(e)}"
    def _analyze_list(self, input_list: List[str]) -> str:
        """
        Analyze a list and return a summary.
        """
        prompt = f"Please analyze this list: {input_list}. Provide a detailed summary focusing on:\n1. Main themes or categories\n2. Key items or elements\n3. Possible relationships or connections\n4. Any patterns or trends observed"
        messages = [self.system_message, HumanMessage(content=prompt)]
        response = self.llm.invoke(messages)
        return response.content if hasattr(response, 'content') else str(response)
    def _analyze_table(self, input_table: List[List[Any]]) -> str:
        """
        Analyze a table and return a summary.
        """
        prompt = f"Please analyze this table: {input_table}. Provide a detailed summary focusing on:\n1. Main themes or categories\n2. Key items or elements\n3. Possible relationships or connections\n4. Any patterns or trends observed"
        messages = [self.system_message, HumanMessage(content=prompt)]
        response = self.llm.invoke(messages)
        return response.content if hasattr(response, 'content') else str(response)
    def _analyze_text(self, text: str) -> str:
        """
        Analyze a text and return a summary.
        """
        prompt = f"Please analyze this text: {text}. Provide a detailed summary focusing on:\n1. Main themes or categories\n2. Key items or elements\n3. Possible relationships or connections\n4. Any patterns or trends observed"
        messages = [self.system_message, HumanMessage(content=prompt)]
        response = self.llm.invoke(messages)
        return response.content if hasattr(response, 'content') else str(response)
    def _analyze_url(self, url: str) -> str:
        """
        Fetch a URL, extract its text content, and return an LLM summary.
        """
        try:
            response = requests.get(url, timeout=30)
            if response.status_code == 200:
                soup = BeautifulSoup(response.text, 'html.parser')
                # Pass the extracted page text (truncated) to the model, not just the URL.
                text = soup.get_text(separator=" ", strip=True)[:8000]
                prompt = f"Please analyze this URL: {url}.\nPage content:\n{text}\n\nProvide a detailed summary focusing on:\n1. Main themes or categories\n2. Key items or elements\n3. Possible relationships or connections\n4. Any patterns or trends observed"
                messages = [self.system_message, HumanMessage(content=prompt)]
                response = self.llm.invoke(messages)
                return response.content if hasattr(response, 'content') else str(response)
            else:
                return f"Error accessing URL: {response.status_code}"
        except Exception as e:
            return f"Error processing URL: {str(e)}"