Arpit-Bansal committed on
Commit 9e67c92 · 1 Parent(s): 975d8af

enhanced_query

Files changed (3)
  1. generate.py +0 -8
  2. langgraph_agent.py +22 -16
  3. requirements.txt +3 -0
generate.py DELETED
@@ -1,8 +0,0 @@
-from google import genai
-from dotenv import load_dotenv
-from os import getenv
-
-load_dotenv()
-
-GEMINI_API_KEY = getenv("GEMINI_API_KEY")
-from .constants import GEMINI_API_KEY
langgraph_agent.py CHANGED
@@ -12,21 +12,6 @@ class AgentState(TypedDict):
     context: List[str]
     response: str
 
-
-
-# def query_classifier(state: AgentState) -> AgentState:
-#     """Determine if the query requires RAG retrieval based on keywords.
-#     Is not continued anymore, will be removed in future."""
-#     query_lower = state["query"].lower()
-#     rag_keywords = [
-#         "scheme", "schemes", "program", "programs", "policy", "policies",
-#         "public health engineering", "phe", "public health", "government",
-#         "benefit", "financial", "assistance", "aid", "initiative", "yojana",
-#     ]
-
-#     state["requires_rag"] = any(keyword in query_lower for keyword in rag_keywords)
-#     return state
-
 def query_classifier(state: AgentState) -> AgentState:
     """Updated classifier to use LLM for intent classification."""
 
@@ -44,6 +29,25 @@ def query_classifier(state: AgentState) -> AgentState:
     state["requires_rag"] = "yes" in result.lower()
     return state
 
+def enhance_query(state:AgentState) -> AgentState:
+    """Enhance the query with user data and context."""
+    previous_conversation = state.get("previous_conversation", "")
+    user_data = state.get("user_data", {})
+    query = state.get("query", "")
+
+    query_enhancement_prompt = f"""
+    Enhance the following query with user data and previous conversation context so it uses the previous conversation and user data.
+    To be used for generating a more relevant and personalized response.
+    Previous Conversation: {previous_conversation}
+    User Data: {user_data}
+    Current Query: {query}
+    Only write the enhanced query. No other text."""
+    result = llm.predict(query_enhancement_prompt)
+    print("Enhanced query: ", result)
+    state["query"] = result
+
+    return state
+
 def retrieve_documents(state: AgentState) -> AgentState:
     """Retrieve documents from vector store if needed."""
     if state["requires_rag"]:
@@ -121,17 +125,19 @@ def create_agent_workflow():
     workflow = StateGraph(AgentState)
 
     # Add nodes
+    workflow.add_node("enhance_query", enhance_query)
     workflow.add_node("classifier", query_classifier)
     workflow.add_node("retriever", retrieve_documents)
     workflow.add_node("responder", generate_response)
 
     # Create edges
+    workflow.add_edge("enhance_query", "classifier")
     workflow.add_edge("classifier", "retriever")
     workflow.add_edge("retriever", "responder")
     workflow.add_edge("responder", END)
 
     # Set the entry point
-    workflow.set_entry_point("classifier")
+    workflow.set_entry_point("enhance_query")
 
     # Compile the graph
     return workflow.compile()
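
With this commit the graph's entry point moves from "classifier" to "enhance_query", so the state passed into the compiled workflow is expected to carry previous_conversation and user_data alongside query. A minimal invocation sketch, not part of the commit: the sample values, and any state fields beyond those visible in the diff, are assumptions for illustration.

# Illustrative sketch only -- assumes langgraph_agent exposes create_agent_workflow()
# as shown above and that the state dict accepts these keys.
from langgraph_agent import create_agent_workflow

app = create_agent_workflow()
initial_state = {
    "query": "Which schemes can I apply for?",                        # raw user question
    "previous_conversation": "User asked about rural water supply.",  # prior turns (assumed format)
    "user_data": {"district": "Bhopal", "occupation": "farmer"},      # profile read by enhance_query
    "requires_rag": False,   # overwritten by the classifier node
    "context": [],           # filled by the retriever node
    "response": "",          # filled by the responder node
}
final_state = app.invoke(initial_state)  # enhance_query -> classifier -> retriever -> responder
print(final_state["response"])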
requirements.txt CHANGED
@@ -273,3 +273,6 @@ xxhash==3.5.0
 yarl==1.18.3
 zipp==3.21.0
 zstandard==0.23.0
+tf-keras
+# pathway[all]
+# langchain-google-genai