Spaces:
Sleeping
Sleeping
Update medical_rag.py
Browse files- medical_rag.py +61 -8
medical_rag.py
CHANGED
@@ -104,55 +104,108 @@ def create_medical_prompt(question, passage):
|
|
104 |
@app.post("/api/chat")
|
105 |
async def chat_endpoint(chat_query: ChatQuery):
|
106 |
try:
|
|
|
|
|
|
|
|
|
107 |
query_text = chat_query.query
|
108 |
-
language_code = chat_query.language_code
|
109 |
if language_code == 0:
|
|
|
110 |
query_text = translate_text(query_text, 'ar_to_en')
|
|
|
111 |
|
112 |
-
# Generate embeddings
|
|
|
113 |
query_embedding = embed_query_text(query_text)
|
114 |
-
|
|
|
|
|
|
|
115 |
embeddings_data = load_embeddings()
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
117 |
initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
|
|
|
|
|
118 |
document_ids = [doc_id for doc_id, _ in initial_results]
|
|
|
|
|
|
|
|
|
|
|
119 |
document_texts = retrieve_document_texts(document_ids, folder_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
-
# Rerank documents with cross-encoder
|
122 |
cross_encoder = models['cross_encoder']
|
123 |
scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
|
124 |
scored_documents = list(zip(scores, document_ids, document_texts))
|
125 |
scored_documents.sort(key=lambda x: x[0], reverse=True)
|
|
|
|
|
|
|
126 |
|
127 |
-
# Extract relevant portions
|
|
|
128 |
relevant_portions = extract_relevant_portions(document_texts, query_text)
|
|
|
|
|
129 |
flattened_relevant_portions = []
|
130 |
for doc_id, portions in relevant_portions.items():
|
131 |
flattened_relevant_portions.extend(portions)
|
132 |
|
133 |
combined_parts = " ".join(flattened_relevant_portions)
|
|
|
|
|
|
|
|
|
134 |
entities = extract_entities(query_text)
|
|
|
|
|
135 |
passage = enhance_passage_with_entities(combined_parts, entities)
|
|
|
136 |
|
137 |
-
#
|
|
|
138 |
prompt = create_medical_prompt(query_text, passage)
|
|
|
|
|
|
|
139 |
answer = get_completion(prompt)
|
|
|
140 |
|
|
|
141 |
final_answer = answer.strip()
|
142 |
if language_code == 0:
|
|
|
143 |
final_answer = translate_text(final_answer, 'en_to_ar')
|
|
|
144 |
|
145 |
if not final_answer:
|
146 |
final_answer = "Sorry, I can't help with that."
|
147 |
-
|
|
|
|
|
148 |
return {
|
149 |
"response": f"I hope this answers your question: {final_answer}",
|
150 |
"success": True
|
151 |
}
|
152 |
|
153 |
except HTTPException as e:
|
|
|
154 |
raise e
|
155 |
except Exception as e:
|
|
|
|
|
156 |
raise HTTPException(status_code=500, detail=str(e))
|
157 |
|
158 |
# Initialize medical models when this module is imported
|
|
|
104 |
@app.post("/api/chat")
|
105 |
async def chat_endpoint(chat_query: ChatQuery):
|
106 |
try:
|
107 |
+
print("\n=== STARTING CHAT REQUEST PROCESSING ===")
|
108 |
+
print(f"Initial query: {chat_query.query} (language_code: {chat_query.language_code})")
|
109 |
+
|
110 |
+
# Step 1: Handle translation if needed
|
111 |
query_text = chat_query.query
|
112 |
+
language_code = chat_query.language_code
|
113 |
if language_code == 0:
|
114 |
+
print("Translating from Arabic to English...")
|
115 |
query_text = translate_text(query_text, 'ar_to_en')
|
116 |
+
print(f"Translated query: {query_text}")
|
117 |
|
118 |
+
# Step 2: Generate embeddings
|
119 |
+
print("\nGenerating query embeddings...")
|
120 |
query_embedding = embed_query_text(query_text)
|
121 |
+
print(f"Embedding generated. Shape: {query_embedding.shape}")
|
122 |
+
|
123 |
+
# Step 3: Load embeddings and query them
|
124 |
+
print("\nLoading document embeddings...")
|
125 |
embeddings_data = load_embeddings()
|
126 |
+
if not embeddings_data:
|
127 |
+
raise HTTPException(status_code=500, detail="Failed to load embeddings data")
|
128 |
+
print(f"Loaded embeddings for {len(embeddings_data)} documents")
|
129 |
+
|
130 |
+
print("\nQuerying embeddings...")
|
131 |
+
n_results = 5
|
132 |
initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
|
133 |
+
print(f"Initial results: {initial_results}")
|
134 |
+
|
135 |
document_ids = [doc_id for doc_id, _ in initial_results]
|
136 |
+
print(f"Document IDs to retrieve: {document_ids}")
|
137 |
+
|
138 |
+
# Step 4: Retrieve document texts
|
139 |
+
print("\nRetrieving document texts...")
|
140 |
+
folder_path = 'downloaded_articles/downloaded_articles'
|
141 |
document_texts = retrieve_document_texts(document_ids, folder_path)
|
142 |
+
print(f"Retrieved {len(document_texts)} documents")
|
143 |
+
|
144 |
+
# Step 5: Rerank documents
|
145 |
+
print("\nReranking documents...")
|
146 |
+
if 'cross_encoder' not in models:
|
147 |
+
raise HTTPException(status_code=500, detail="Cross-encoder model not loaded")
|
148 |
|
|
|
149 |
cross_encoder = models['cross_encoder']
|
150 |
scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
|
151 |
scored_documents = list(zip(scores, document_ids, document_texts))
|
152 |
scored_documents.sort(key=lambda x: x[0], reverse=True)
|
153 |
+
print("Top 3 reranked documents:")
|
154 |
+
for i, (score, doc_id, _) in enumerate(scored_documents[:3]):
|
155 |
+
print(f"{i+1}. Doc {doc_id} (score: {score:.4f})")
|
156 |
|
157 |
+
# Step 6: Extract relevant portions
|
158 |
+
print("\nExtracting relevant portions...")
|
159 |
relevant_portions = extract_relevant_portions(document_texts, query_text)
|
160 |
+
print(f"Found relevant portions in {len(relevant_portions)} documents")
|
161 |
+
|
162 |
flattened_relevant_portions = []
|
163 |
for doc_id, portions in relevant_portions.items():
|
164 |
flattened_relevant_portions.extend(portions)
|
165 |
|
166 |
combined_parts = " ".join(flattened_relevant_portions)
|
167 |
+
print(f"Combined relevant text length: {len(combined_parts)} characters")
|
168 |
+
|
169 |
+
# Step 7: Extract and enhance with entities
|
170 |
+
print("\nExtracting entities...")
|
171 |
entities = extract_entities(query_text)
|
172 |
+
print(f"Found entities: {entities}")
|
173 |
+
|
174 |
passage = enhance_passage_with_entities(combined_parts, entities)
|
175 |
+
print(f"Enhanced passage length: {len(passage)} characters")
|
176 |
|
177 |
+
# Step 8: Generate response
|
178 |
+
print("\nCreating prompt...")
|
179 |
prompt = create_medical_prompt(query_text, passage)
|
180 |
+
print(f"Prompt length: {len(prompt)} characters")
|
181 |
+
|
182 |
+
print("\nGetting completion from DeepSeek...")
|
183 |
answer = get_completion(prompt)
|
184 |
+
print(f"Raw answer received: {answer[:200]}...") # Print first 200 chars
|
185 |
|
186 |
+
# Step 9: Final processing
|
187 |
final_answer = answer.strip()
|
188 |
if language_code == 0:
|
189 |
+
print("\nTranslating answer to Arabic...")
|
190 |
final_answer = translate_text(final_answer, 'en_to_ar')
|
191 |
+
print(f"Translated answer: {final_answer[:200]}...")
|
192 |
|
193 |
if not final_answer:
|
194 |
final_answer = "Sorry, I can't help with that."
|
195 |
+
print("Warning: Empty answer received")
|
196 |
+
|
197 |
+
print("\n=== REQUEST PROCESSING COMPLETE ===")
|
198 |
return {
|
199 |
"response": f"I hope this answers your question: {final_answer}",
|
200 |
"success": True
|
201 |
}
|
202 |
|
203 |
except HTTPException as e:
|
204 |
+
print(f"\n!!! HTTPException: {e.detail}")
|
205 |
raise e
|
206 |
except Exception as e:
|
207 |
+
print(f"\n!!! Unexpected error: {str(e)}")
|
208 |
+
print(f"Error type: {type(e).__name__}")
|
209 |
raise HTTPException(status_code=500, detail=str(e))
|
210 |
|
211 |
# Initialize medical models when this module is imported
|