Spaces:
Sleeping
Sleeping
Commit
·
9e67c92
1
Parent(s):
975d8af
enhanced_query
Browse files- generate.py +0 -8
- langgraph_agent.py +22 -16
- requirements.txt +3 -0
generate.py
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
from google import genai
|
2 |
-
from dotenv import load_dotenv
|
3 |
-
from os import getenv
|
4 |
-
|
5 |
-
load_dotenv()
|
6 |
-
|
7 |
-
GEMINI_API_KEY = getenv("GEMINI_API_KEY")
|
8 |
-
from .constants import GEMINI_API_KEY
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
langgraph_agent.py
CHANGED
@@ -12,21 +12,6 @@ class AgentState(TypedDict):
|
|
12 |
context: List[str]
|
13 |
response: str
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
# def query_classifier(state: AgentState) -> AgentState:
|
18 |
-
# """Determine if the query requires RAG retrieval based on keywords.
|
19 |
-
# Is not continued anymore, will be removed in future."""
|
20 |
-
# query_lower = state["query"].lower()
|
21 |
-
# rag_keywords = [
|
22 |
-
# "scheme", "schemes", "program", "programs", "policy", "policies",
|
23 |
-
# "public health engineering", "phe", "public health", "government",
|
24 |
-
# "benefit", "financial", "assistance", "aid", "initiative", "yojana",
|
25 |
-
# ]
|
26 |
-
|
27 |
-
# state["requires_rag"] = any(keyword in query_lower for keyword in rag_keywords)
|
28 |
-
# return state
|
29 |
-
|
30 |
def query_classifier(state: AgentState) -> AgentState:
|
31 |
"""Updated classifier to use LLM for intent classification."""
|
32 |
|
@@ -44,6 +29,25 @@ def query_classifier(state: AgentState) -> AgentState:
|
|
44 |
state["requires_rag"] = "yes" in result.lower()
|
45 |
return state
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
def retrieve_documents(state: AgentState) -> AgentState:
|
48 |
"""Retrieve documents from vector store if needed."""
|
49 |
if state["requires_rag"]:
|
@@ -121,17 +125,19 @@ def create_agent_workflow():
|
|
121 |
workflow = StateGraph(AgentState)
|
122 |
|
123 |
# Add nodes
|
|
|
124 |
workflow.add_node("classifier", query_classifier)
|
125 |
workflow.add_node("retriever", retrieve_documents)
|
126 |
workflow.add_node("responder", generate_response)
|
127 |
|
128 |
# Create edges
|
|
|
129 |
workflow.add_edge("classifier", "retriever")
|
130 |
workflow.add_edge("retriever", "responder")
|
131 |
workflow.add_edge("responder", END)
|
132 |
|
133 |
# Set the entry point
|
134 |
-
workflow.set_entry_point("
|
135 |
|
136 |
# Compile the graph
|
137 |
return workflow.compile()
|
|
|
12 |
context: List[str]
|
13 |
response: str
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def query_classifier(state: AgentState) -> AgentState:
|
16 |
"""Updated classifier to use LLM for intent classification."""
|
17 |
|
|
|
29 |
state["requires_rag"] = "yes" in result.lower()
|
30 |
return state
|
31 |
|
32 |
+
def enhance_query(state:AgentState) -> AgentState:
|
33 |
+
"""Enhance the query with user data and context."""
|
34 |
+
previous_conversation = state.get("previous_conversation", "")
|
35 |
+
user_data = state.get("user_data", {})
|
36 |
+
query = state.get("query", "")
|
37 |
+
|
38 |
+
query_enhancement_prompt = f"""
|
39 |
+
Enhance the following query with user data and previous conversation context so it uses the previous conversation and user data.
|
40 |
+
To be used for generating a more relevant and personalized response.
|
41 |
+
Previous Conversation: {previous_conversation}
|
42 |
+
User Data: {user_data}
|
43 |
+
Current Query: {query}
|
44 |
+
Only write the enhanced query. No other text."""
|
45 |
+
result = llm.predict(query_enhancement_prompt)
|
46 |
+
print("Enhanced query: ", result)
|
47 |
+
state["query"] = result
|
48 |
+
|
49 |
+
return state
|
50 |
+
|
51 |
def retrieve_documents(state: AgentState) -> AgentState:
|
52 |
"""Retrieve documents from vector store if needed."""
|
53 |
if state["requires_rag"]:
|
|
|
125 |
workflow = StateGraph(AgentState)
|
126 |
|
127 |
# Add nodes
|
128 |
+
workflow.add_node("enhance_query", enhance_query)
|
129 |
workflow.add_node("classifier", query_classifier)
|
130 |
workflow.add_node("retriever", retrieve_documents)
|
131 |
workflow.add_node("responder", generate_response)
|
132 |
|
133 |
# Create edges
|
134 |
+
workflow.add_edge("enhance_query", "classifier")
|
135 |
workflow.add_edge("classifier", "retriever")
|
136 |
workflow.add_edge("retriever", "responder")
|
137 |
workflow.add_edge("responder", END)
|
138 |
|
139 |
# Set the entry point
|
140 |
+
workflow.set_entry_point("enhance_query")
|
141 |
|
142 |
# Compile the graph
|
143 |
return workflow.compile()
|
requirements.txt
CHANGED
@@ -273,3 +273,6 @@ xxhash==3.5.0
|
|
273 |
yarl==1.18.3
|
274 |
zipp==3.21.0
|
275 |
zstandard==0.23.0
|
|
|
|
|
|
|
|
273 |
yarl==1.18.3
|
274 |
zipp==3.21.0
|
275 |
zstandard==0.23.0
|
276 |
+
tf-keras
|
277 |
+
# pathway[all]
|
278 |
+
# langchain-google-genai
|