File size: 3,844 Bytes
5fd03ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384a4dc
5fd03ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from setup import *
import re
import requests
from typing import Annotated, Sequence, List, Optional
from typing_extensions import TypedDict

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.graph.message import add_messages
from langgraph.graph import START, StateGraph, END
from langgraph.checkpoint.memory import MemorySaver


# Research agent
class AgentState(TypedDict):
    """Shared state dictionary passed between the nodes of the research graph."""

    # Conversation messages; `add_messages` is attached as annotation metadata.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Search queries, written by `company_and_industry_query`.
    queries: List[str]
    # Accumulated search results, written by `web_scraping`.
    link_list: Optional[List]
    # Industry name parsed from the assistant reply (None when not found).
    industry: Optional[str]
    # Company name parsed from the assistant reply (None when not found).
    company: Optional[str]



# Node
def assistant(state: AgentState):
    """Ask the LLM which company and industry the conversation refers to.

    A fixed system prompt is placed in front of ``state["messages"]`` and the
    combined list is passed to ``llm.invoke``.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update whose ``messages`` entry is a one-element list
        holding the value returned by ``llm.invoke``.
    """
    # The prompt asks for `Company: "..."` / `Industry: "..."` lines, which is
    # the format `company_and_industry_query` parses afterwards.
    instructions = '''You are a highly intelligent and helpful assistant. Your primary task is to analyze user queries and determine whether the query:

    Refers to an industry (general context)
    Refers to a specific company (e.g., mentions a company's name explicitly).

    For every query:
    Check for company names, brands, or proper nouns that indicate a specific entity.
    While analyzing the company industry be specific as possible.
    Return the company and industry name in the query
    if you can't find a industry name, return an empty string.
                            
    Example 1:
    Query: "GenAI in MRF Tyres"
    Company: "MRF Tyres"
    Industry: "Tires and rubber products"

    Example 2:  
    Query: "GenAI in the healthcare industry"
    Company: ""
    Industry: "Healthcare"
    '''
    prompt = [SystemMessage(content=instructions)] + state["messages"]
    reply = llm.invoke(prompt)
    return {'messages': [reply]}



def company_and_industry_query(state: AgentState):
    """Turn the assistant's latest reply into a list of web-search queries.

    The content of the last message in ``state['messages']`` is scanned for
    ``Company: "<name>"`` and ``Industry: "<name>"`` fields. Three queries are
    generated for each field that carries a non-blank value.

    Args:
        state: Graph state; only ``state['messages'][-1].content`` is read.

    Returns:
        A partial state update with ``queries`` (list of search strings, empty
        when neither field is present) plus ``company`` and ``industry``
        (the stripped name, or ``None`` when the field is missing, empty or
        whitespace-only).
    """
    text = state['messages'][-1].content

    def _quoted_field(label):
        # Return the double-quoted value following ``<label>:``, or None.
        match = re.search(rf'{label}:\s*"([^"]+)"', text)
        if match is None:
            return None
        # A whitespace-only value is treated the same as a missing one so
        # that no queries are generated for a blank name.
        return match.group(1).strip() or None

    company_name = _quoted_field('Company')
    industry_name = _quoted_field('Industry')

    queries = []
    if company_name:
        queries.extend([
            f'{company_name} Annual report latest AND {company_name} website AND no PDF results',
            f'{company_name} GenAI applications',
            f'{company_name} key offerings and strategic focus areas (e.g., operations, supply chain, customer experience)',
        ])

    if industry_name:
        queries.extend([
            f'{industry_name} report latest AND (Mckinsey OR deloitte OR nexocode OR research report)',
            f'{industry_name} GenAI applications',
            f'{industry_name} trends, challenges and opportunities',
        ])

    return {'queries': queries, 'company': company_name, 'industry': industry_name}


def web_scraping(state: AgentState):
    """Run each entry of ``state['queries']`` through ``tavily_search``.

    The items returned by every ``tavily_search.invoke`` call are pooled, in
    query order, into one flat list.

    Returns:
        A partial state update with the pooled items under ``link_list``.
    """
    collected = [
        hit
        for search_query in state['queries']
        for hit in tavily_search.invoke({"query": search_query})
    ]
    return {'link_list': collected}


# Agent Graph
def research_agent(user_query: str):
    """Build the research graph and run it once for ``user_query``.

    The graph is linear: START -> assistant -> names_extract -> web_scraping
    -> END. It is compiled with a fresh ``MemorySaver`` checkpointer and
    invoked with thread id ``'1'``.

    Args:
        user_query: Text wrapped in a ``HumanMessage`` as the initial message.

    Returns:
        Whatever the compiled graph's ``invoke`` call returns.
    """
    graph_builder = StateGraph(AgentState)

    # Register the nodes in pipeline order.
    for node_name, node_fn in (
        ('assistant', assistant),
        ('names_extract', company_and_industry_query),
        ('web_scraping', web_scraping),
    ):
        graph_builder.add_node(node_name, node_fn)

    # Wire the nodes into a single straight path.
    for source, target in (
        (START, "assistant"),
        ("assistant", "names_extract"),
        ("names_extract", 'web_scraping'),
        ("web_scraping", END),
    ):
        graph_builder.add_edge(source, target)

    # In-memory checkpointer, created anew on every call.
    compiled_graph = graph_builder.compile(checkpointer=MemorySaver())

    run_config = {'configurable': {'thread_id': '1'}}
    initial_state = {'messages': [HumanMessage(content=user_query)]}
    return compiled_graph.invoke(initial_state, run_config)