Ali2206 committed on
Commit
b90a0eb
·
verified ·
1 Parent(s): 3b7144e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -28
app.py CHANGED
@@ -130,7 +130,8 @@ def init_agent():
130
  enable_checker=True,
131
  step_rag_num=8,
132
  seed=100,
133
- additional_default_tools=[]
 
134
  )
135
  agent.init_model()
136
  return agent
@@ -140,7 +141,7 @@ def create_ui(agent: TxAgent):
140
  gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
141
  gr.Markdown("<h3 style='text-align: center;'>Identify potential oversights in patient care</h3>")
142
 
143
- chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
144
  file_upload = gr.File(
145
  label="Upload Medical Records",
146
  file_types=[".pdf", ".csv", ".xls", ".xlsx"],
@@ -149,23 +150,26 @@ def create_ui(agent: TxAgent):
149
  msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
150
  send_btn = gr.Button("Analyze", variant="primary")
151
  conversation_state = gr.State([])
152
- download_output = gr.File(label="Download Full Report (after tools finish)")
153
 
154
  def analyze_potential_oversights(message: str, history: list, conversation: list, files: list):
155
  start_time = time.time()
156
  try:
157
- history.append({"role": "user", "content": message})
 
158
  yield history, None
159
 
 
160
  extracted_data = ""
161
  file_hash_value = ""
162
- if files:
163
  with ThreadPoolExecutor(max_workers=4) as executor:
164
- futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files if hasattr(f, 'name')]
165
- results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
166
- extracted_data = "\n".join(results)
167
- file_hash_value = file_hash(files[0].name)
168
 
 
169
  analysis_prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
170
  1. List potential missed diagnoses
171
  2. Flag any medication conflicts
@@ -174,10 +178,16 @@ def create_ui(agent: TxAgent):
174
 
175
  Medical Records:\n{extracted_data[:15000]}
176
 
177
- ### Potential Oversights:\n"""
178
 
179
- response = ""
180
- for chunk in agent.run_gradio_chat(
 
 
 
 
 
 
181
  message=analysis_prompt,
182
  history=[],
183
  temperature=0.2,
@@ -185,23 +195,45 @@ Medical Records:\n{extracted_data[:15000]}
185
  max_token=4096,
186
  call_agent=False,
187
  conversation=conversation
188
- ):
189
- partial = ""
190
- if isinstance(chunk, str):
191
- partial = chunk
192
- elif isinstance(chunk, list):
193
- partial = "".join([c.content for c in chunk if hasattr(c, 'content')])
194
- response += partial
195
- history[-1] = {"role": "assistant", "content": response.strip()}
196
- yield history, None
197
-
198
- report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
199
- return history, report_path if os.path.exists(report_path) else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  except Exception as e:
202
- history.append({"role": "assistant", "content": f"Analysis failed: {str(e)}"})
203
- return history, None
 
204
 
 
205
  inputs = [msg_input, chatbot, conversation_state, file_upload]
206
  outputs = [chatbot, download_output]
207
  send_btn.click(analyze_potential_oversights, inputs=inputs, outputs=outputs)
@@ -218,11 +250,28 @@ Medical Records:\n{extracted_data[:15000]}
218
  if __name__ == "__main__":
219
  print("Initializing medical analysis agent...")
220
  agent = init_agent()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  print("Launching interface...")
222
  demo = create_ui(agent)
223
- demo.queue().launch(
224
  server_name="0.0.0.0",
225
  server_port=7860,
226
  show_error=True,
227
  share=True
228
- )
 
130
  enable_checker=True,
131
  step_rag_num=8,
132
  seed=100,
133
+ additional_default_tools=[],
134
+ device_map="auto"
135
  )
136
  agent.init_model()
137
  return agent
 
141
  gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
142
  gr.Markdown("<h3 style='text-align: center;'>Identify potential oversights in patient care</h3>")
143
 
144
+ chatbot = gr.Chatbot(label="Analysis", height=600)
145
  file_upload = gr.File(
146
  label="Upload Medical Records",
147
  file_types=[".pdf", ".csv", ".xls", ".xlsx"],
 
150
  msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
151
  send_btn = gr.Button("Analyze", variant="primary")
152
  conversation_state = gr.State([])
153
+ download_output = gr.File(label="Download Full Report")
154
 
155
  def analyze_potential_oversights(message: str, history: list, conversation: list, files: list):
156
  start_time = time.time()
157
  try:
158
+ # Initialize conversation
159
+ history.append((message, "Analyzing records for potential oversights..."))
160
  yield history, None
161
 
162
+ # Process files
163
  extracted_data = ""
164
  file_hash_value = ""
165
+ if files and isinstance(files, list):
166
  with ThreadPoolExecutor(max_workers=4) as executor:
167
+ futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower())
168
+ for f in files if hasattr(f, 'name')]
169
+ extracted_data = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
170
+ file_hash_value = file_hash(files[0].name) if files else ""
171
 
172
+ # Medical oversight analysis prompt
173
  analysis_prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
174
  1. List potential missed diagnoses
175
  2. Flag any medication conflicts
 
178
 
179
  Medical Records:\n{extracted_data[:15000]}
180
 
181
+ Provide ONLY the potential oversights in this format:
182
 
183
+ ### Potential Oversights:
184
+ 1. [Missed diagnosis] - [Evidence from records]
185
+ 2. [Medication issue] - [Supporting data]
186
+ 3. [Assessment gap] - [Relevant findings]"""
187
+
188
+ # Generate and stream response
189
+ full_response = ""
190
+ generator = agent.run_gradio_chat(
191
  message=analysis_prompt,
192
  history=[],
193
  temperature=0.2,
 
195
  max_token=4096,
196
  call_agent=False,
197
  conversation=conversation
198
+ )
199
+
200
+ for update in generator:
201
+ if not update:
202
+ continue
203
+
204
+ if isinstance(update, str):
205
+ full_response += update
206
+ elif isinstance(update, list):
207
+ full_response += "".join([msg.content for msg in update if hasattr(msg, 'content')])
208
+
209
+ # Clean and update the response
210
+ cleaned = full_response.replace("[TOOL_CALLS]", "").strip()
211
+ if cleaned:
212
+ history[-1] = (message, cleaned)
213
+ yield history, None
214
+
215
+ # Final cleaned response
216
+ final_output = full_response.replace("[TOOL_CALLS]", "").strip()
217
+ if not final_output:
218
+ final_output = "No clear oversights identified. Recommend comprehensive review."
219
+
220
+ # Prepare report path if available
221
+ report_path = None
222
+ if file_hash_value:
223
+ possible_report = os.path.join(report_dir, f"{file_hash_value}_report.txt")
224
+ if os.path.exists(possible_report):
225
+ report_path = possible_report
226
+
227
+ history[-1] = (message, final_output)
228
+ print(f"Final analysis:\n{final_output}")
229
+ yield history, report_path
230
 
231
  except Exception as e:
232
+ print(f"Analysis error: {str(e)}")
233
+ history[-1] = (message, f"❌ Analysis failed: {str(e)}")
234
+ yield history, None
235
 
236
+ # UI event handlers
237
  inputs = [msg_input, chatbot, conversation_state, file_upload]
238
  outputs = [chatbot, download_output]
239
  send_btn.click(analyze_potential_oversights, inputs=inputs, outputs=outputs)
 
250
  if __name__ == "__main__":
251
  print("Initializing medical analysis agent...")
252
  agent = init_agent()
253
+
254
+ print("Performing warm-up call...")
255
+ try:
256
+ warm_up = agent.run_gradio_chat(
257
+ message="Warm up",
258
+ history=[],
259
+ temperature=0.1,
260
+ max_new_tokens=10,
261
+ max_token=100,
262
+ call_agent=False,
263
+ conversation=[]
264
+ )
265
+ for _ in warm_up:
266
+ pass
267
+ except Exception as e:
268
+ print(f"Warm-up error: {str(e)}")
269
+
270
  print("Launching interface...")
271
  demo = create_ui(agent)
272
+ demo.queue(concurrency_count=2).launch(
273
  server_name="0.0.0.0",
274
  server_port=7860,
275
  show_error=True,
276
  share=True
277
+ )