|
from setup import * |
|
import re |
|
import requests |
|
from typing import Annotated, Sequence, List, Optional |
|
from typing_extensions import TypedDict |
|
|
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage |
|
from langgraph.graph.message import add_messages |
|
from langgraph.graph import START, StateGraph, END |
|
from langgraph.checkpoint.memory import MemorySaver |
|
|
|
|
|
|
|
class AgentState(TypedDict):
    """State dictionary shared by the nodes of the research graph."""

    # Conversation messages; annotated with the `add_messages` reducer.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Search queries built by `company_and_industry_query`.
    queries: List[str]
    # Search results gathered by `web_scraping`.
    link_list: Optional[List]
    # Industry name parsed from the assistant reply (None when absent).
    industry: Optional[str]
    # Company name parsed from the assistant reply (None when absent).
    company: Optional[str]
|
|
|
|
|
|
|
|
|
def assistant(state: AgentState):
    """Ask the LLM which company and/or industry the user query refers to.

    The system prompt is sent ahead of the conversation stored in
    ``state["messages"]``. The prompt pins the reply to ``Company: "..."`` and
    ``Industry: "..."`` lines because ``company_and_industry_query`` extracts
    the names with regular expressions that expect exactly that shape.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        dict: ``{'messages': [reply]}`` where ``reply`` is whatever
        ``llm.invoke`` returns for the prompt plus conversation.
    """
    assistant_sys_msg = SystemMessage(content='''You are a highly intelligent and helpful assistant. Your primary task is to analyze user queries and determine whether the query:

- Refers to an industry (general context).
- Refers to a specific company (e.g., mentions a company's name explicitly).

For every query:
- Check for company names, brands, or proper nouns that indicate a specific entity.
- When identifying the company's industry, be as specific as possible.
- Return the company and industry name in the query.
- If you can't find a company name or an industry name, return an empty string for it.

Always answer with exactly these two lines, keeping the double quotes (use "" when a value is missing):
Company: "<company name>"
Industry: "<industry name>"

Example 1:
Query: "GenAI in MRF Tyres"
Company: "MRF Tyres"
Industry: "Tires and rubber products"

Example 2:
Query: "GenAI in the healthcare industry"
Company: ""
Industry: "Healthcare"
''')

    # Unpack instead of `list + ...` so any Sequence of messages (not only a
    # list) can be prefixed with the system message.
    return {'messages': [llm.invoke([assistant_sys_msg, *state["messages"]])]}
|
|
|
|
|
|
|
def _quoted_field(pattern, text):
    """Return the stripped first capture group of *pattern* in *text*, or None."""
    match = re.search(pattern, text)
    if match is None:
        return None
    # A whitespace-only value (e.g. `Company: " "`) carries no name; treat it
    # like a missing field so no junk search queries are generated from it.
    return match.group(1).strip() or None


def company_and_industry_query(state: AgentState):
    """Parse company/industry names from the last message and build search queries.

    The last message's ``content`` is expected to contain lines such as
    ``Company: "MRF Tyres"`` and ``Industry: "Tires and rubber products"``.
    An empty pair of quotes, a blank value, or a missing line yields ``None``
    for that field.

    Args:
        state: Graph state; only ``state['messages'][-1].content`` is read.

    Returns:
        dict: ``{'queries': [...], 'company': str | None, 'industry': str | None}``.
        Three queries are added per name found (company first, then industry).
    """
    text = state['messages'][-1].content

    company_name = _quoted_field(r'Company:\s*"([^"]+)"', text)
    industry_name = _quoted_field(r'Industry:\s*"([^"]+)"', text)

    queries = []

    if company_name:
        queries.extend([
            f'{company_name} Annual report latest AND {company_name} website AND no PDF results',
            f'{company_name} GenAI applications',
            f'{company_name} key offerings and strategic focus areas (e.g., operations, supply chain, customer experience)',
        ])

    if industry_name:
        queries.extend([
            f'{industry_name} report latest AND (Mckinsey OR deloitte OR nexocode OR research report)',
            f'{industry_name} GenAI applications',
            # Spelling fixed ("oppurtunities") so the search matches real pages.
            f'{industry_name} trends, challenges and opportunities',
        ])

    return {'queries': queries, 'company': company_name, 'industry': industry_name}
|
|
|
|
|
def web_scraping(state: AgentState):
    """Run each query in ``state['queries']`` through ``tavily_search`` and pool the results.

    Args:
        state: Graph state; only ``state['queries']`` is read.

    Returns:
        dict: ``{'link_list': [...]}`` holding the results of all queries,
        concatenated in query order. Queries whose result is unusable are
        skipped with a warning.
    """
    import logging  # function-scope import keeps this block self-contained

    logger = logging.getLogger(__name__)
    link_list = []

    for query in state['queries']:
        query_results = tavily_search.invoke({"query": query})

        # NOTE(review): presumably the tool returns a list of hits on success
        # and may hand back a plain string (e.g. error text) on failure -
        # confirm against the tool configured in `setup`. Extending a list
        # with a string would add one entry per character, and None is not
        # iterable, so both are skipped instead of corrupting/crashing.
        if query_results is None or isinstance(query_results, str):
            logger.warning(
                "Skipping unusable search result for query %r: %r",
                query,
                query_results,
            )
            continue

        link_list.extend(query_results)

    return {'link_list': link_list}
|
|
|
|
|
|
|
def research_agent(user_query: str):
    """Build the research graph and run it once for ``user_query``.

    The graph is a straight pipeline:
    START -> assistant -> names_extract -> web_scraping -> END.
    It is compiled with a fresh ``MemorySaver`` checkpointer and invoked with
    a single ``HumanMessage`` under thread id ``'1'``.

    Args:
        user_query: Text of the user's request.

    Returns:
        Whatever the compiled graph's ``invoke`` returns for this run.
    """
    graph_builder = StateGraph(AgentState)

    # Register the nodes in pipeline order.
    node_table = (
        ('assistant', assistant),
        ('names_extract', company_and_industry_query),
        ('web_scraping', web_scraping),
    )
    for node_name, node_fn in node_table:
        graph_builder.add_node(node_name, node_fn)

    # Connect each stage to the one that follows it.
    pipeline = (START, "assistant", "names_extract", 'web_scraping', END)
    for source, target in zip(pipeline, pipeline[1:]):
        graph_builder.add_edge(source, target)

    compiled_graph = graph_builder.compile(checkpointer=MemorySaver())

    run_config = {'configurable': {'thread_id': '1'}}
    initial_state = {'messages': [HumanMessage(content=user_query)]}

    return compiled_graph.invoke(initial_state, run_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|