Spaces:
Sleeping
Sleeping
Commit
·
975d8af
1
Parent(s):
f925355
Updated Query Classifier
Browse files- langgraph_agent.py +30 -20
- main.py +0 -1
langgraph_agent.py
CHANGED
@@ -1,11 +1,9 @@
|
|
1 |
-
# Import LangGraph components
|
2 |
from langgraph.graph import StateGraph, END
|
3 |
-
from typing import TypedDict, List, Dict, Any
|
4 |
-
import operator
|
5 |
-
from pydantic import BaseModel, Field
|
6 |
from agent import LLMChain, PromptTemplate, llm, DOCUMENT_DIR, load_documents, split_documents, CHROMA_PATH, load_vectordb, create_and_store_embeddings
|
7 |
import os
|
8 |
-
|
|
|
9 |
class AgentState(TypedDict):
|
10 |
query: str
|
11 |
previous_conversation: str
|
@@ -14,18 +12,36 @@ class AgentState(TypedDict):
|
|
14 |
context: List[str]
|
15 |
response: str
|
16 |
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def query_classifier(state: AgentState) -> AgentState:
|
20 |
-
"""
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
|
|
29 |
return state
|
30 |
|
31 |
def retrieve_documents(state: AgentState) -> AgentState:
|
@@ -183,18 +199,12 @@ def agent_with_db():
|
|
183 |
"requires_rag": False,
|
184 |
"context": [],
|
185 |
"response": "",
|
186 |
-
# "style": style
|
187 |
}
|
188 |
-
print("Initial state:", initial_state)
|
189 |
|
190 |
-
# Run the workflow
|
191 |
final_state = self.workflow.invoke(initial_state)
|
192 |
-
print("Final state:", final_state)
|
193 |
|
194 |
-
# Update conversation history with response
|
195 |
self.conversation_history += f"Assistant: {final_state['response']}\n"
|
196 |
|
197 |
-
# Return in the expected format
|
198 |
return {"result": final_state["response"]}
|
199 |
|
200 |
return HealthAgent(agent_workflow)
|
|
|
|
|
1 |
from langgraph.graph import StateGraph, END
|
2 |
+
from typing import TypedDict, List, Dict, Any
|
|
|
|
|
3 |
from agent import LLMChain, PromptTemplate, llm, DOCUMENT_DIR, load_documents, split_documents, CHROMA_PATH, load_vectordb, create_and_store_embeddings
|
4 |
import os
|
5 |
+
|
6 |
+
# state schema
|
7 |
class AgentState(TypedDict):
|
8 |
query: str
|
9 |
previous_conversation: str
|
|
|
12 |
context: List[str]
|
13 |
response: str
|
14 |
|
15 |
+
|
16 |
+
|
17 |
+
# def query_classifier(state: AgentState) -> AgentState:
|
18 |
+
# """Determine if the query requires RAG retrieval based on keywords.
|
19 |
+
# Is not continued anymore, will be removed in future."""
|
20 |
+
# query_lower = state["query"].lower()
|
21 |
+
# rag_keywords = [
|
22 |
+
# "scheme", "schemes", "program", "programs", "policy", "policies",
|
23 |
+
# "public health engineering", "phe", "public health", "government",
|
24 |
+
# "benefit", "financial", "assistance", "aid", "initiative", "yojana",
|
25 |
+
# ]
|
26 |
+
|
27 |
+
# state["requires_rag"] = any(keyword in query_lower for keyword in rag_keywords)
|
28 |
+
# return state
|
29 |
|
30 |
def query_classifier(state: AgentState) -> AgentState:
|
31 |
+
"""Updated classifier to use LLM for intent classification."""
|
32 |
+
|
33 |
+
query = state["query"]
|
34 |
+
|
35 |
+
# Then classify intent
|
36 |
+
classification_prompt = f"""
|
37 |
+
Answer with only 'Yes' or 'No'.
|
38 |
+
Classify if this query is asking about government schemes, policies, or benefits.
|
39 |
+
The language may not be English, So first detect the language. and understand the query.:
|
40 |
+
Query: {query}
|
41 |
+
Remember Answer with only 'Yes' or 'No'."""
|
42 |
|
43 |
+
result = llm.predict(classification_prompt)
|
44 |
+
state["requires_rag"] = "yes" in result.lower()
|
45 |
return state
|
46 |
|
47 |
def retrieve_documents(state: AgentState) -> AgentState:
|
|
|
199 |
"requires_rag": False,
|
200 |
"context": [],
|
201 |
"response": "",
|
|
|
202 |
}
|
|
|
203 |
|
|
|
204 |
final_state = self.workflow.invoke(initial_state)
|
|
|
205 |
|
|
|
206 |
self.conversation_history += f"Assistant: {final_state['response']}\n"
|
207 |
|
|
|
208 |
return {"result": final_state["response"]}
|
209 |
|
210 |
return HealthAgent(agent_workflow)
|
main.py
CHANGED
@@ -61,7 +61,6 @@ async def retrieve(request:request, url:Request):
|
|
61 |
if origin is None:
|
62 |
origin = url.headers.get('referer')
|
63 |
print("origin: ", origin)
|
64 |
-
print("response: ", response)
|
65 |
return {"response": response["result"]}
|
66 |
|
67 |
except Exception as e:
|
|
|
61 |
if origin is None:
|
62 |
origin = url.headers.get('referer')
|
63 |
print("origin: ", origin)
|
|
|
64 |
return {"response": response["result"]}
|
65 |
|
66 |
except Exception as e:
|