File size: 3,844 Bytes
5fd03ff 384a4dc 5fd03ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
from setup import *
import re
import requests
from typing import Annotated, Sequence, List, Optional
from typing_extensions import TypedDict
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.graph.message import add_messages
from langgraph.graph import START, StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
# Research agent
class AgentState(TypedDict):
    """Shared state that flows between the nodes of the research graph.

    Each node returns a partial dict containing only the keys it updates;
    LangGraph folds those updates into this state.
    """

    # Conversation history. The ``add_messages`` reducer merges new messages
    # into the existing sequence instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Search queries produced by the name-extraction node.
    queries: List[str]
    # Search results collected by the web-search node.
    link_list: Optional[List]
    # Industry name extracted from the assistant's reply, or None.
    industry: Optional[str]
    # Company name extracted from the assistant's reply, or None.
    company: Optional[str]
# Node
# System prompt for the assistant node. The reply format it mandates
# (``Company: "..."`` / ``Industry: "..."``) is exactly what the
# name-extraction node parses, so the format is stated explicitly rather
# than only implied by the examples.
_ASSISTANT_SYSTEM_PROMPT = '''You are a highly intelligent and helpful assistant. Your primary task is to analyze the user's query and determine whether it:
- refers to an industry (general context), or
- refers to a specific company (e.g., mentions a company's name explicitly).
For every query:
- Check for company names, brands, or proper nouns that indicate a specific entity.
- When identifying the industry, be as specific as possible.
- If you cannot find a company name, return an empty string for the company.
- If you cannot find an industry name, return an empty string for the industry.
Respond with exactly these two lines and nothing else, keeping the double quotes and using no markdown formatting:
Company: "<company name>"
Industry: "<industry name>"
Example 1:
Query: "GenAI in MRF Tyres"
Company: "MRF Tyres"
Industry: "Tires and rubber products"
Example 2:
Query: "GenAI in the healthcare industry"
Company: ""
Industry: "Healthcare"
'''


def assistant(state: AgentState):
    """Ask the LLM to identify the company and industry named in the query.

    The system prompt is prepended to the conversation in
    ``state["messages"]`` and the model's reply is returned as a state
    update, to be appended to the message history.

    Args:
        state: Current graph state; only ``messages`` is read.

    Returns:
        ``{'messages': [reply]}`` where ``reply`` is the object returned by
        ``llm.invoke``.
    """
    system_message = SystemMessage(content=_ASSISTANT_SYSTEM_PROMPT)
    return {'messages': [llm.invoke([system_message] + state["messages"])]}
# Patterns for the 'Company: "..."' / 'Industry: "..."' lines in the
# assistant's reply. They also accept markdown bold around the label
# (e.g. '**Company:** "X"' or '**Company**: "X"') and any letter case.
_COMPANY_PATTERN = re.compile(r'Company\**\s*:\s*\**\s*"([^"]*)"', re.IGNORECASE)
_INDUSTRY_PATTERN = re.compile(r'Industry\**\s*:\s*\**\s*"([^"]*)"', re.IGNORECASE)


def _extract_labeled_value(pattern, text):
    """Return the stripped quoted value captured by *pattern*, or None if absent or blank."""
    match = pattern.search(text)
    if match is None:
        return None
    value = match.group(1).strip()
    # An empty or whitespace-only value means the model found nothing.
    return value or None


def company_and_industry_query(state: AgentState):
    """Parse company/industry names from the last message and build search queries.

    Reads the content of the most recent message (the assistant's reply) and
    extracts the quoted values after the ``Company:`` and ``Industry:``
    labels. Three queries are generated for a company and three for an
    industry; either group is omitted when its name is missing.

    Args:
        state: Current graph state; only ``messages`` is read.

    Returns:
        A state update with keys ``queries`` (list of str), ``company`` and
        ``industry`` (str, or None when not found).
    """
    # NOTE(review): assumes the reply content is a plain string; LangChain
    # message content can also be a list of content blocks — confirm for the
    # configured ``llm``.
    text = state['messages'][-1].content

    company_name = _extract_labeled_value(_COMPANY_PATTERN, text)
    industry_name = _extract_labeled_value(_INDUSTRY_PATTERN, text)

    queries = []
    if company_name:
        queries.extend([
            f'{company_name} Annual report latest AND {company_name} website AND no PDF results',
            f'{company_name} GenAI applications',
            f'{company_name} key offerings and strategic focus areas (e.g., operations, supply chain, customer experience)',
        ])
    if industry_name:
        queries.extend([
            f'{industry_name} report latest AND (Mckinsey OR deloitte OR nexocode OR research report)',
            f'{industry_name} GenAI applications',
            f'{industry_name} trends, challenges and opportunities',
        ])
    return {'queries': queries, 'company': company_name, 'industry': industry_name}
def web_scraping(state: AgentState):
    """Run every query in ``state['queries']`` through ``tavily_search``.

    Results from all queries are concatenated in query order.

    Args:
        state: Current graph state; only ``queries`` is read.

    Returns:
        ``{'link_list': results}`` with the combined result items.
    """
    import logging

    link_list = []
    for query in state['queries']:
        query_results = tavily_search.invoke({"query": query})
        # NOTE(review): depending on the Tavily tool class, ``invoke`` may
        # return a list of results, a dict wrapping them under "results", or
        # an error string. ``list.extend`` on a str would add its characters
        # (and on a dict, its keys), so only real result lists are collected.
        if isinstance(query_results, dict) and isinstance(query_results.get("results"), list):
            query_results = query_results["results"]
        if isinstance(query_results, list):
            link_list.extend(query_results)
        else:
            logging.getLogger(__name__).warning(
                "Skipping non-list search result for query %r: %r", query, query_results
            )
    return {'link_list': link_list}
# Agent Graph
def research_agent(user_query: str):
    """Run the research graph once for *user_query* and return its final state.

    The graph is a straight line: assistant -> names_extract -> web_scraping.
    A new graph and a new in-memory checkpointer are created on every call,
    so checkpoints are not shared between calls.
    """
    graph_builder = StateGraph(AgentState)

    node_table = (
        ('assistant', assistant),
        ('names_extract', company_and_industry_query),
        ('web_scraping', web_scraping),
    )
    for node_name, node_fn in node_table:
        graph_builder.add_node(node_name, node_fn)

    # Wire each consecutive pair into an edge: START -> ... -> END.
    edge_chain = (START, "assistant", "names_extract", 'web_scraping', END)
    for source, target in zip(edge_chain, edge_chain[1:]):
        graph_builder.add_edge(source, target)

    compiled_graph = graph_builder.compile(checkpointer=MemorySaver())
    run_config = {'configurable': {'thread_id': '1'}}
    initial_input = {'messages': [HumanMessage(content=user_query)]}
    return compiled_graph.invoke(initial_input, run_config)
|