Arpit-Bansal committed on
Commit
975d8af
·
1 Parent(s): f925355

Updated Query Classifier

Browse files
Files changed (2) hide show
  1. langgraph_agent.py +30 -20
  2. main.py +0 -1
langgraph_agent.py CHANGED
@@ -1,11 +1,9 @@
1
- # Import LangGraph components
2
  from langgraph.graph import StateGraph, END
3
- from typing import TypedDict, List, Dict, Any, Annotated, Union, Literal
4
- import operator
5
- from pydantic import BaseModel, Field
6
  from agent import LLMChain, PromptTemplate, llm, DOCUMENT_DIR, load_documents, split_documents, CHROMA_PATH, load_vectordb, create_and_store_embeddings
7
  import os
8
- # Define state schema
 
9
  class AgentState(TypedDict):
10
  query: str
11
  previous_conversation: str
@@ -14,18 +12,36 @@ class AgentState(TypedDict):
14
  context: List[str]
15
  response: str
16
 
17
- # Define tools and nodes for the LangGraph
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def query_classifier(state: AgentState) -> AgentState:
20
- """Determine if the query requires RAG retrieval."""
21
- query_lower = state["query"].lower()
22
- rag_keywords = [
23
- "scheme", "schemes", "program", "programs", "policy", "policies",
24
- "public health engineering", "phe", "public health", "government",
25
- "benefit", "financial", "assistance", "aid", "initiative"
26
- ]
 
 
 
 
27
 
28
- state["requires_rag"] = any(keyword in query_lower for keyword in rag_keywords)
 
29
  return state
30
 
31
  def retrieve_documents(state: AgentState) -> AgentState:
@@ -183,18 +199,12 @@ def agent_with_db():
183
  "requires_rag": False,
184
  "context": [],
185
  "response": "",
186
- # "style": style
187
  }
188
- print("Initial state:", initial_state)
189
 
190
- # Run the workflow
191
  final_state = self.workflow.invoke(initial_state)
192
- print("Final state:", final_state)
193
 
194
- # Update conversation history with response
195
  self.conversation_history += f"Assistant: {final_state['response']}\n"
196
 
197
- # Return in the expected format
198
  return {"result": final_state["response"]}
199
 
200
  return HealthAgent(agent_workflow)
 
 
1
  from langgraph.graph import StateGraph, END
2
+ from typing import TypedDict, List, Dict, Any
 
 
3
  from agent import LLMChain, PromptTemplate, llm, DOCUMENT_DIR, load_documents, split_documents, CHROMA_PATH, load_vectordb, create_and_store_embeddings
4
  import os
5
+
6
+ # state schema
7
  class AgentState(TypedDict):
8
  query: str
9
  previous_conversation: str
 
12
  context: List[str]
13
  response: str
14
 
15
+
16
+
17
+ # def query_classifier(state: AgentState) -> AgentState:
18
+ # """Determine if the query requires RAG retrieval based on keywords.
19
+ # Is not continued anymore, will be removed in future."""
20
+ # query_lower = state["query"].lower()
21
+ # rag_keywords = [
22
+ # "scheme", "schemes", "program", "programs", "policy", "policies",
23
+ # "public health engineering", "phe", "public health", "government",
24
+ # "benefit", "financial", "assistance", "aid", "initiative", "yojana",
25
+ # ]
26
+
27
+ # state["requires_rag"] = any(keyword in query_lower for keyword in rag_keywords)
28
+ # return state
29
 
30
  def query_classifier(state: AgentState) -> AgentState:
31
+ """Updated classifier to use LLM for intent classification."""
32
+
33
+ query = state["query"]
34
+
35
+ # Then classify intent
36
+ classification_prompt = f"""
37
+ Answer with only 'Yes' or 'No'.
38
+ Classify if this query is asking about government schemes, policies, or benefits.
39
+ The language may not be English, So first detect the language. and understand the query.:
40
+ Query: {query}
41
+ Remember Answer with only 'Yes' or 'No'."""
42
 
43
+ result = llm.predict(classification_prompt)
44
+ state["requires_rag"] = "yes" in result.lower()
45
  return state
46
 
47
  def retrieve_documents(state: AgentState) -> AgentState:
 
199
  "requires_rag": False,
200
  "context": [],
201
  "response": "",
 
202
  }
 
203
 
 
204
  final_state = self.workflow.invoke(initial_state)
 
205
 
 
206
  self.conversation_history += f"Assistant: {final_state['response']}\n"
207
 
 
208
  return {"result": final_state["response"]}
209
 
210
  return HealthAgent(agent_workflow)
main.py CHANGED
@@ -61,7 +61,6 @@ async def retrieve(request:request, url:Request):
61
  if origin is None:
62
  origin = url.headers.get('referer')
63
  print("origin: ", origin)
64
- print("response: ", response)
65
  return {"response": response["result"]}
66
 
67
  except Exception as e:
 
61
  if origin is None:
62
  origin = url.headers.get('referer')
63
  print("origin: ", origin)
 
64
  return {"response": response["result"]}
65
 
66
  except Exception as e: