Spaces:
Running
Running
Commit
·
28e1a34
1
Parent(s):
df0b0dd
Modified CustomChain, code-linting
Browse files
agent.py
CHANGED
@@ -28,30 +28,6 @@ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", GEMINI_API_KEY=GEMINI_API
|
|
28 |
|
29 |
|
30 |
print("Models initialized successfully.")
|
31 |
-
def load_documents(directory):
|
32 |
-
loader = DirectoryLoader(
|
33 |
-
directory,
|
34 |
-
glob="**/*.txt",
|
35 |
-
loader_cls=TextLoader
|
36 |
-
)
|
37 |
-
documents = loader.load()
|
38 |
-
|
39 |
-
docx_loader = DirectoryLoader(
|
40 |
-
directory,
|
41 |
-
glob="**/*.docx",
|
42 |
-
loader_cls=UnstructuredWordDocumentLoader,
|
43 |
-
loader_kwargs={"mode": "elements"}
|
44 |
-
)
|
45 |
-
documents.extend(docx_loader.load())
|
46 |
-
print(f"Loaded {len(documents)} documents.")
|
47 |
-
pdf_loader = DirectoryLoader(
|
48 |
-
directory,
|
49 |
-
glob="**/*.pdf",
|
50 |
-
loader_cls=UnstructuredPDFLoader
|
51 |
-
)
|
52 |
-
documents.extend(pdf_loader.load())
|
53 |
-
print(f"Loaded {len(documents)} documents.")
|
54 |
-
return documents
|
55 |
|
56 |
import os
|
57 |
from dotenv import load_dotenv
|
@@ -200,7 +176,7 @@ def create_health_agent(vector_store):
|
|
200 |
|
201 |
@property
|
202 |
def input_keys(self):
|
203 |
-
return ['query', 'previous_conversation']
|
204 |
|
205 |
@property
|
206 |
def output_keys(self):
|
@@ -209,6 +185,7 @@ def create_health_agent(vector_store):
|
|
209 |
def _call(self, inputs):
|
210 |
query = inputs['query']
|
211 |
previous_conversation = inputs.get('previous_conversation', '')
|
|
|
212 |
|
213 |
# Retrieve relevant documents
|
214 |
docs = retriever.get_relevant_documents(query)
|
@@ -218,7 +195,8 @@ def create_health_agent(vector_store):
|
|
218 |
llm_inputs = {
|
219 |
'context': context,
|
220 |
'question': query,
|
221 |
-
'previous_conversation': previous_conversation
|
|
|
222 |
}
|
223 |
|
224 |
# Generate response
|
@@ -229,7 +207,7 @@ def create_health_agent(vector_store):
|
|
229 |
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
230 |
|
231 |
# Create and return the custom chain
|
232 |
-
return CustomRetrievalQA(retriever=retriever, llm_chain=llm_chain)
|
233 |
|
234 |
|
235 |
def agent_with_db():
|
|
|
28 |
|
29 |
|
30 |
print("Models initialized successfully.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
import os
|
33 |
from dotenv import load_dotenv
|
|
|
176 |
|
177 |
@property
|
178 |
def input_keys(self):
|
179 |
+
return ['query', 'previous_conversation', 'user_data']
|
180 |
|
181 |
@property
|
182 |
def output_keys(self):
|
|
|
185 |
def _call(self, inputs):
|
186 |
query = inputs['query']
|
187 |
previous_conversation = inputs.get('previous_conversation', '')
|
188 |
+
user_data = inputs.get('user_data', '')
|
189 |
|
190 |
# Retrieve relevant documents
|
191 |
docs = retriever.get_relevant_documents(query)
|
|
|
195 |
llm_inputs = {
|
196 |
'context': context,
|
197 |
'question': query,
|
198 |
+
'previous_conversation': previous_conversation,
|
199 |
+
'user_data': user_data
|
200 |
}
|
201 |
|
202 |
# Generate response
|
|
|
207 |
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
208 |
|
209 |
# Create and return the custom chain
|
210 |
+
return CustomRetrievalQA(retriever=retriever, llm_chain=llm_chain, user_data=None)
|
211 |
|
212 |
|
213 |
def agent_with_db():
|