thechaiexperiment committed on
Commit
ddae698
·
verified ·
1 Parent(s): c21e3a0

Update medical_rag.py

Browse files
Files changed (1) hide show
  1. medical_rag.py +61 -8
medical_rag.py CHANGED
@@ -104,55 +104,108 @@ def create_medical_prompt(question, passage):
104
@app.post("/api/chat")
async def chat_endpoint(chat_query: ChatQuery):
    """Answer a medical question via retrieval-augmented generation.

    Pipeline: optional AR->EN translation, embedding retrieval,
    cross-encoder reranking, relevant-portion extraction, entity
    enhancement, LLM completion, optional EN->AR back-translation.

    Args:
        chat_query: request body with `query` text and `language_code`
            (0 means Arabic input/output; anything else is treated as
            English — presumably; confirm against the client contract).

    Returns:
        dict with the final `response` string and a `success` flag.

    Raises:
        HTTPException: re-raised as-is from inner steps, or 500 wrapping
            any unexpected error.
    """
    try:
        query_text = chat_query.query
        language_code = chat_query.language_code
        if language_code == 0:
            # Arabic queries are processed internally in English.
            query_text = translate_text(query_text, 'ar_to_en')

        # Generate embeddings and retrieve the top candidate documents.
        query_embedding = embed_query_text(query_text)
        n_results = 5
        embeddings_data = load_embeddings()
        folder_path = 'downloaded_articles/downloaded_articles'
        initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
        document_ids = [doc_id for doc_id, _ in initial_results]
        document_texts = retrieve_document_texts(document_ids, folder_path)

        # Rerank candidates with the cross-encoder (higher score first).
        cross_encoder = models['cross_encoder']
        scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
        scored_documents = list(zip(scores, document_ids, document_texts))
        scored_documents.sort(key=lambda x: x[0], reverse=True)

        # BUG FIX: the original computed the reranking but then extracted
        # portions from the unranked `document_texts`, discarding the
        # cross-encoder's work. Use the reranked order instead.
        reranked_texts = [doc for _, _, doc in scored_documents]
        relevant_portions = extract_relevant_portions(reranked_texts, query_text)
        flattened_relevant_portions = []
        for doc_id, portions in relevant_portions.items():
            flattened_relevant_portions.extend(portions)
        combined_parts = " ".join(flattened_relevant_portions)

        # Enrich the passage with entities detected in the query.
        entities = extract_entities(query_text)
        passage = enhance_passage_with_entities(combined_parts, entities)

        # Create the medical-specific prompt and get the completion.
        prompt = create_medical_prompt(query_text, passage)
        answer = get_completion(prompt)

        final_answer = answer.strip()
        if language_code == 0:
            # Translate the answer back for Arabic callers.
            final_answer = translate_text(final_answer, 'en_to_ar')

        if not final_answer:
            final_answer = "Sorry, I can't help with that."

        return {
            "response": f"I hope this answers your question: {final_answer}",
            "success": True
        }

    except HTTPException:
        # Propagate deliberate HTTP errors unchanged (bare `raise`
        # preserves the original traceback).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Initialize medical models when this module is imported
 
104
  @app.post("/api/chat")
105
  async def chat_endpoint(chat_query: ChatQuery):
106
  try:
107
+ print("\n=== STARTING CHAT REQUEST PROCESSING ===")
108
+ print(f"Initial query: {chat_query.query} (language_code: {chat_query.language_code})")
109
+
110
+ # Step 1: Handle translation if needed
111
  query_text = chat_query.query
112
+ language_code = chat_query.language_code
113
  if language_code == 0:
114
+ print("Translating from Arabic to English...")
115
  query_text = translate_text(query_text, 'ar_to_en')
116
+ print(f"Translated query: {query_text}")
117
 
118
+ # Step 2: Generate embeddings
119
+ print("\nGenerating query embeddings...")
120
  query_embedding = embed_query_text(query_text)
121
+ print(f"Embedding generated. Shape: {query_embedding.shape}")
122
+
123
+ # Step 3: Load embeddings and query them
124
+ print("\nLoading document embeddings...")
125
  embeddings_data = load_embeddings()
126
+ if not embeddings_data:
127
+ raise HTTPException(status_code=500, detail="Failed to load embeddings data")
128
+ print(f"Loaded embeddings for {len(embeddings_data)} documents")
129
+
130
+ print("\nQuerying embeddings...")
131
+ n_results = 5
132
  initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
133
+ print(f"Initial results: {initial_results}")
134
+
135
  document_ids = [doc_id for doc_id, _ in initial_results]
136
+ print(f"Document IDs to retrieve: {document_ids}")
137
+
138
+ # Step 4: Retrieve document texts
139
+ print("\nRetrieving document texts...")
140
+ folder_path = 'downloaded_articles/downloaded_articles'
141
  document_texts = retrieve_document_texts(document_ids, folder_path)
142
+ print(f"Retrieved {len(document_texts)} documents")
143
+
144
+ # Step 5: Rerank documents
145
+ print("\nReranking documents...")
146
+ if 'cross_encoder' not in models:
147
+ raise HTTPException(status_code=500, detail="Cross-encoder model not loaded")
148
 
 
149
  cross_encoder = models['cross_encoder']
150
  scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
151
  scored_documents = list(zip(scores, document_ids, document_texts))
152
  scored_documents.sort(key=lambda x: x[0], reverse=True)
153
+ print("Top 3 reranked documents:")
154
+ for i, (score, doc_id, _) in enumerate(scored_documents[:3]):
155
+ print(f"{i+1}. Doc {doc_id} (score: {score:.4f})")
156
 
157
+ # Step 6: Extract relevant portions
158
+ print("\nExtracting relevant portions...")
159
  relevant_portions = extract_relevant_portions(document_texts, query_text)
160
+ print(f"Found relevant portions in {len(relevant_portions)} documents")
161
+
162
  flattened_relevant_portions = []
163
  for doc_id, portions in relevant_portions.items():
164
  flattened_relevant_portions.extend(portions)
165
 
166
  combined_parts = " ".join(flattened_relevant_portions)
167
+ print(f"Combined relevant text length: {len(combined_parts)} characters")
168
+
169
+ # Step 7: Extract and enhance with entities
170
+ print("\nExtracting entities...")
171
  entities = extract_entities(query_text)
172
+ print(f"Found entities: {entities}")
173
+
174
  passage = enhance_passage_with_entities(combined_parts, entities)
175
+ print(f"Enhanced passage length: {len(passage)} characters")
176
 
177
+ # Step 8: Generate response
178
+ print("\nCreating prompt...")
179
  prompt = create_medical_prompt(query_text, passage)
180
+ print(f"Prompt length: {len(prompt)} characters")
181
+
182
+ print("\nGetting completion from DeepSeek...")
183
  answer = get_completion(prompt)
184
+ print(f"Raw answer received: {answer[:200]}...") # Print first 200 chars
185
 
186
+ # Step 9: Final processing
187
  final_answer = answer.strip()
188
  if language_code == 0:
189
+ print("\nTranslating answer to Arabic...")
190
  final_answer = translate_text(final_answer, 'en_to_ar')
191
+ print(f"Translated answer: {final_answer[:200]}...")
192
 
193
  if not final_answer:
194
  final_answer = "Sorry, I can't help with that."
195
+ print("Warning: Empty answer received")
196
+
197
+ print("\n=== REQUEST PROCESSING COMPLETE ===")
198
  return {
199
  "response": f"I hope this answers your question: {final_answer}",
200
  "success": True
201
  }
202
 
203
  except HTTPException as e:
204
+ print(f"\n!!! HTTPException: {e.detail}")
205
  raise e
206
  except Exception as e:
207
+ print(f"\n!!! Unexpected error: {str(e)}")
208
+ print(f"Error type: {type(e).__name__}")
209
  raise HTTPException(status_code=500, detail=str(e))
210
 
211
  # Initialize medical models when this module is imported