37-AN
commited on
Commit
·
b725ad2
1
Parent(s):
6c6cf17
Fix output keys format with wrapper function
Browse files- app/core/agent.py +5 -15
- app/core/memory.py +24 -2
app/core/agent.py
CHANGED
@@ -50,23 +50,13 @@ Assistant:"""
|
|
50 |
|
51 |
# Use the RAG chain to get an answer
|
52 |
response = self.rag_chain({"question": question})
|
|
|
53 |
|
54 |
-
# Extract the answer
|
55 |
-
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
# Create a fallback answer if the expected key is missing
|
60 |
-
answer = "I'm sorry, I encountered an issue processing your request. Let me try a simpler response."
|
61 |
-
else:
|
62 |
-
answer = response["answer"]
|
63 |
-
|
64 |
-
# Handle different variations of source document keys
|
65 |
-
source_docs = []
|
66 |
-
if "source_documents" in response:
|
67 |
-
source_docs = response["source_documents"]
|
68 |
-
elif "sources" in response:
|
69 |
-
source_docs = response["sources"]
|
70 |
|
71 |
# Format source documents for display
|
72 |
sources = []
|
|
|
50 |
|
51 |
# Use the RAG chain to get an answer
|
52 |
response = self.rag_chain({"question": question})
|
53 |
+
logger.info(f"Raw response keys: {list(response.keys())}")
|
54 |
|
55 |
+
# Extract the answer (should now be normalized by our wrapper)
|
56 |
+
answer = response.get("answer", "I couldn't generate a proper response.")
|
57 |
|
58 |
+
# Extract sources (should now be normalized by our wrapper)
|
59 |
+
source_docs = response.get("sources", [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
# Format source documents for display
|
62 |
sources = []
|
app/core/memory.py
CHANGED
@@ -126,7 +126,29 @@ class MemoryManager:
|
|
126 |
return_source_documents=True,
|
127 |
return_generated_question=False,
|
128 |
)
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
except Exception as e:
|
131 |
logger.error(f"Error creating RAG chain: {e}")
|
132 |
|
@@ -138,7 +160,7 @@ class MemoryManager:
|
|
138 |
logger.info(f"Mock chain received query: {inputs.get('question', '')}")
|
139 |
return {
|
140 |
"answer": "I'm having trouble accessing the knowledge base. I can only answer general questions right now.",
|
141 |
-
"
|
142 |
}
|
143 |
|
144 |
return mock_chain
|
|
|
126 |
return_source_documents=True,
|
127 |
return_generated_question=False,
|
128 |
)
|
129 |
+
|
130 |
+
# Create a wrapper function that normalizes the chain output format
|
131 |
+
def normalized_chain(inputs):
|
132 |
+
logger.info("Executing RAG chain with normalizer")
|
133 |
+
try:
|
134 |
+
# Execute the original chain
|
135 |
+
response = chain(inputs)
|
136 |
+
logger.info(f"Original chain output keys: {list(response.keys())}")
|
137 |
+
|
138 |
+
# Create a normalized response
|
139 |
+
normalized = {
|
140 |
+
"answer": response.get("answer", "No answer generated"),
|
141 |
+
"sources": response.get("source_documents", [])
|
142 |
+
}
|
143 |
+
return normalized
|
144 |
+
except Exception as e:
|
145 |
+
logger.error(f"Error in normalized chain: {e}")
|
146 |
+
return {
|
147 |
+
"answer": f"Error processing your query: {str(e)}",
|
148 |
+
"sources": []
|
149 |
+
}
|
150 |
+
|
151 |
+
return normalized_chain
|
152 |
except Exception as e:
|
153 |
logger.error(f"Error creating RAG chain: {e}")
|
154 |
|
|
|
160 |
logger.info(f"Mock chain received query: {inputs.get('question', '')}")
|
161 |
return {
|
162 |
"answer": "I'm having trouble accessing the knowledge base. I can only answer general questions right now.",
|
163 |
+
"sources": []
|
164 |
}
|
165 |
|
166 |
return mock_chain
|