Dharma20 committed on
Commit
5fd03ff
·
verified ·
1 Parent(s): ce4be9a

Update agents.py

Browse files
Files changed (1) hide show
  1. agents.py +117 -121
agents.py CHANGED
@@ -1,121 +1,117 @@
1
- from setup import *
2
- import re
3
- import requests
4
- from typing import Annotated, Sequence, List, Optional
5
- from typing_extensions import TypedDict
6
-
7
- from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
8
- from langgraph.graph.message import add_messages
9
- from langgraph.graph import START, StateGraph, END
10
- from langgraph.checkpoint.memory import MemorySaver
11
-
12
-
13
- # Research agent
14
- class AgentState(TypedDict):
15
- messages: Annotated[Sequence[BaseMessage], add_messages]
16
- queries : List[str]
17
- link_list : Optional[List]
18
- industry : Optional[str]
19
- company: Optional[str]
20
-
21
-
22
-
23
- # Node
24
- def assistant(state: AgentState):
25
- assistant_sys_msg = SystemMessage(content='''You are a highly intelligent and helpful assistant. Your primary task is to analyze user queries and determine whether the query:
26
-
27
- Refers to an industry (general context)
28
- Refers to a specific company (e.g., mentions a company's name explicitly).
29
-
30
- For every query:
31
- Check for company names, brands, or proper nouns that indicate a specific entity.
32
- While analyzing the company industry be specific as possible.
33
- Return the company and industry name in the query
34
- if you can't find a industry name, return an empty string.
35
-
36
- Example 1:
37
- Query: "GenAI in MRF Tyres"
38
- Company: "MRF Tyres"
39
- Industry: "Tires and rubber products"
40
-
41
- Example 2:
42
- Query: "GenAI in the healthcare industry"
43
- Company: ""
44
- Industry: "Healthcare"
45
- ''')
46
- return {'messages': [llm.invoke([assistant_sys_msg] + state["messages"])]}
47
-
48
-
49
-
50
- def company_and_industry_query(state: AgentState):
51
- print('--extract_company_and_industry--entered--')
52
- text = state['messages'][-1].content
53
-
54
- # Define patterns for extracting company and industry
55
- company_pattern = r'Company:\s*"([^"]+)"'
56
- industry_pattern = r'Industry:\s*"([^"]+)"'
57
-
58
- # Search for matches
59
- company_match = re.search(company_pattern, text)
60
- industry_match = re.search(industry_pattern, text)
61
-
62
- # Extract matched groups or return None if not found
63
- company_name = company_match.group(1) if company_match else None
64
- industry_name = industry_match.group(1) if industry_match else None
65
- queries = []
66
- if company_name:
67
- queries.extend([f'{company_name} Annual report latest AND {company_name} website AND no PDF results',
68
- f'{company_name} GenAI applications',
69
- f'{company_name} key offerings and strategic focus areas (e.g., operations, supply chain, customer experience)',
70
- ])
71
-
72
- if industry_name:
73
- queries.extend([
74
- f'{industry_name} report latest mckinsey, deloitte, nexocode',
75
- f'{industry_name} GenAI applications',
76
- f'{industry_name} trends, challenges and oppurtunities'
77
- ])
78
-
79
- print('--extract_company_and_industry--finished--', queries)
80
- return {'queries': queries, 'company': company_name, 'industry': industry_name}
81
-
82
-
83
- def web_scraping(state: AgentState):
84
- print('--web_scraping--entered--')
85
- queries = state['queries']
86
- link_list = []
87
- for query in queries:
88
- query_results = tavily_search.invoke({"query": query})
89
- link_list.extend(query_results)
90
-
91
- print('--web_scraping--finished--')
92
- return {'link_list': link_list}
93
-
94
-
95
- # Agent Graph
96
- def research_agent(user_query: str):
97
- builder = StateGraph(AgentState)
98
- builder.add_node('assistant', assistant)
99
- builder.add_node('names_extract', company_and_industry_query)
100
- builder.add_node('web_scraping', web_scraping)
101
-
102
- builder.add_edge(START, "assistant")
103
- builder.add_edge("assistant", "names_extract")
104
- builder.add_edge("names_extract", 'web_scraping')
105
- builder.add_edge("web_scraping", END)
106
-
107
- # memory
108
- memory = MemorySaver()
109
- react_graph = builder.compile(checkpointer=memory)
110
-
111
- config = {'configurable': {'thread_id':'1'}}
112
- messages = [HumanMessage(content=user_query)]
113
- agentstate_result = react_graph.invoke({'messages': messages}, config)
114
-
115
- return agentstate_result
116
-
117
-
118
-
119
-
120
-
121
-
 
1
+ from setup import *
2
+ import re
3
+ import requests
4
+ from typing import Annotated, Sequence, List, Optional
5
+ from typing_extensions import TypedDict
6
+
7
+ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
8
+ from langgraph.graph.message import add_messages
9
+ from langgraph.graph import START, StateGraph, END
10
+ from langgraph.checkpoint.memory import MemorySaver
11
+
12
+
13
+ # Research agent
14
class AgentState(TypedDict):
    """State dictionary shared by the nodes of the research graph."""

    # Conversation messages; annotated with the `add_messages` reducer.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Search queries built by the `company_and_industry_query` node.
    queries: List[str]
    # Search results gathered by the `web_scraping` node.
    link_list: Optional[List]
    # Industry name parsed from the assistant reply (None when not found).
    industry: Optional[str]
    # Company name parsed from the assistant reply (None when not found).
    company: Optional[str]
20
+
21
+
22
+
23
+ # Node
24
def assistant(state: AgentState):
    """Ask the LLM which company and/or industry the user's query refers to.

    A fixed system prompt is placed in front of the messages already stored in
    ``state["messages"]`` and the combined list is sent to the module-level
    ``llm``. The model's reply is returned under the ``'messages'`` key.

    The prompt asks for ``Company: "..."`` and ``Industry: "..."`` lines, which
    is the layout ``company_and_industry_query`` parses from the last message.
    """
    # The prompt is runtime text: it is kept verbatim and flush-left so that no
    # extra indentation ends up inside the string sent to the model.
    instructions = '''You are a highly intelligent and helpful assistant. Your primary task is to analyze user queries and determine whether the query:

Refers to an industry (general context)
Refers to a specific company (e.g., mentions a company's name explicitly).

For every query:
Check for company names, brands, or proper nouns that indicate a specific entity.
While analyzing the company industry be specific as possible.
Return the company and industry name in the query
if you can't find a industry name, return an empty string.

Example 1:
Query: "GenAI in MRF Tyres"
Company: "MRF Tyres"
Industry: "Tires and rubber products"

Example 2:
Query: "GenAI in the healthcare industry"
Company: ""
Industry: "Healthcare"
'''
    system_prompt = SystemMessage(content=instructions)
    conversation = [system_prompt] + state["messages"]
    reply = llm.invoke(conversation)
    return {'messages': [reply]}
47
+
48
+
49
+
50
def _extract_field(label: str, text: str) -> Optional[str]:
    """Return the value that follows ``label:`` in *text*, or None if absent/empty."""
    # Preferred layout, exactly as requested in the assistant prompt:
    #   Label: "value"
    # Tried first so replies that already parsed correctly keep parsing the same way.
    quoted = re.search(rf'{re.escape(label)}:\s*"([^"]+)"', text)
    if quoted:
        return quoted.group(1)

    # Fallback for replies that drop the double quotes, use single quotes or
    # wrap the label in markdown emphasis: take the rest of the line and trim
    # quote / emphasis characters from both ends.
    loose = re.search(rf'{re.escape(label)}:[ \t]*(.*)', text)
    if loose is None:
        return None
    value = loose.group(1).strip(' \t\r"\'*`')
    # An empty value (e.g. `Company: ""`) means "not present".
    return value or None


def company_and_industry_query(state: AgentState):
    """Build web-search queries from the assistant's company/industry answer.

    Reads the content of the last message in ``state['messages']``, extracts
    the ``Company:`` and ``Industry:`` values and produces three search queries
    for each value that was found.

    Returns:
        A dict with ``'queries'`` (list of query strings, possibly empty),
        ``'company'`` and ``'industry'`` (the extracted names, or None).
    """
    reply_text = state['messages'][-1].content

    company_name = _extract_field('Company', reply_text)
    industry_name = _extract_field('Industry', reply_text)

    queries = []
    if company_name:
        queries.extend([
            f'{company_name} Annual report latest AND {company_name} website AND no PDF results',
            f'{company_name} GenAI applications',
            f'{company_name} key offerings and strategic focus areas (e.g., operations, supply chain, customer experience)',
        ])

    if industry_name:
        queries.extend([
            f'{industry_name} report latest mckinsey, deloitte, nexocode',
            f'{industry_name} GenAI applications',
            # Spelling fixed ("oppurtunities") so the search engine gets a clean query.
            f'{industry_name} trends, challenges and opportunities',
        ])

    return {'queries': queries, 'company': company_name, 'industry': industry_name}
79
+
80
+
81
def web_scraping(state: AgentState):
    """Run every query in ``state['queries']`` through the search tool.

    Each query is passed to the module-level ``tavily_search`` tool and the
    results are concatenated in query order.

    Returns:
        A dict with ``'link_list'``: the combined list of search results
        (empty when there are no queries or no usable results).
    """
    link_list = []
    for query in state['queries']:
        query_results = tavily_search.invoke({"query": query})
        # Only sequences of results are merged. `list.extend` on a string
        # would splice it in one character at a time.
        # NOTE(review): `tavily_search` comes from `setup` and is not visible
        # here; some Tavily tool versions return an error string on failure --
        # confirm the return type.
        if isinstance(query_results, (list, tuple)):
            link_list.extend(query_results)

    return {'link_list': link_list}
89
+
90
+
91
+ # Agent Graph
92
def research_agent(user_query: str):
    """Build the research graph, run it on *user_query*, and return the final state.

    The graph is a straight pipeline:
    START -> assistant -> names_extract -> web_scraping -> END.

    Args:
        user_query: The user's question, sent to the graph as a HumanMessage.

    Returns:
        Whatever the compiled graph's ``invoke`` returns for this run.
    """
    graph_builder = StateGraph(AgentState)

    # Register the nodes in the order they run.
    pipeline_nodes = [
        ('assistant', assistant),
        ('names_extract', company_and_industry_query),
        ('web_scraping', web_scraping),
    ]
    for node_name, node_fn in pipeline_nodes:
        graph_builder.add_node(node_name, node_fn)

    # Connect each stage to the next one, from START through to END.
    route = [START, "assistant", "names_extract", 'web_scraping', END]
    for source, target in zip(route, route[1:]):
        graph_builder.add_edge(source, target)

    # A fresh in-memory checkpointer is created for every call.
    compiled_graph = graph_builder.compile(checkpointer=MemorySaver())

    run_config = {'configurable': {'thread_id': '1'}}
    initial_state = {'messages': [HumanMessage(content=user_query)]}
    return compiled_graph.invoke(initial_state, run_config)
112
+
113
+
114
+
115
+
116
+
117
+