Ali2206 committed on
Commit 32e4e6a · verified · 1 Parent(s): c4e7e4a

Update app.py

Files changed (1):
  1. app.py +86 -348
app.py CHANGED
@@ -1,3 +1,4 @@
+import sys
 import os
 import json
 import shutil
@@ -24,18 +25,24 @@ tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
 file_cache_dir = os.path.join(persistent_dir, "cache")
 report_dir = os.path.join(persistent_dir, "reports")
 
+# Create directories if they don't exist
 for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
     os.makedirs(d, exist_ok=True)
 
+# Set environment variables
 os.environ["HF_HOME"] = model_cache_dir
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
+os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"  # Fix for matplotlib permission issues
 
+# Set up Python path
 current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
 sys.path.insert(0, src_path)
 
+# Import TxAgent after setting up paths
 from txagent.txagent import TxAgent
 
+# Constants
 MAX_MODEL_TOKENS = 131072
 MAX_NEW_TOKENS = 4096
 MAX_CHUNK_TOKENS = 8192
@@ -43,9 +50,12 @@ BATCH_SIZE = 1
 PROMPT_OVERHEAD = 300
 SAFE_SLEEP = 0.5
 
-app = FastAPI(title="Clinical Patient Support System API",
-              description="API for analyzing and summarizing unstructured medical files",
-              version="1.0.0")
+# Initialize FastAPI app
+app = FastAPI(
+    title="Clinical Patient Support System API",
+    description="API for analyzing and summarizing unstructured medical files",
+    version="1.0.0"
+)
 
 # CORS configuration
 app.add_middleware(
@@ -62,107 +72,17 @@ agent = None
 @app.on_event("startup")
 async def startup_event():
     global agent
-    agent = init_agent()
-
-def estimate_tokens(text: str) -> int:
-    return len(text) // 4 + 1
-
-def clean_response(text: str) -> str:
-    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
-    text = re.sub(r"\n{3,}", "\n\n", text)
-    return text.strip()
-
-def remove_duplicate_paragraphs(text: str) -> str:
-    paragraphs = text.strip().split("\n\n")
-    seen = set()
-    unique_paragraphs = []
-    for p in paragraphs:
-        clean_p = p.strip()
-        if clean_p and clean_p not in seen:
-            unique_paragraphs.append(clean_p)
-            seen.add(clean_p)
-    return "\n\n".join(unique_paragraphs)
-
-def extract_text_from_excel(path: str) -> str:
-    all_text = []
-    xls = pd.ExcelFile(path)
-    for sheet_name in xls.sheet_names:
-        try:
-            df = xls.parse(sheet_name).astype(str).fillna("")
-        except Exception:
-            continue
-        for _, row in df.iterrows():
-            non_empty = [cell.strip() for cell in row if cell.strip()]
-            if len(non_empty) >= 2:
-                text_line = " | ".join(non_empty)
-                if len(text_line) > 15:
-                    all_text.append(f"[{sheet_name}] {text_line}")
-    return "\n".join(all_text)
-
-def extract_text_from_csv(path: str) -> str:
-    all_text = []
     try:
-        df = pd.read_csv(path).astype(str).fillna("")
-    except Exception:
-        return ""
-    for _, row in df.iterrows():
-        non_empty = [cell.strip() for cell in row if cell.strip()]
-        if len(non_empty) >= 2:
-            text_line = " | ".join(non_empty)
-            if len(text_line) > 15:
-                all_text.append(text_line)
-    return "\n".join(all_text)
-
-def extract_text_from_pdf(path: str) -> str:
-    import logging
-    logging.getLogger("pdfminer").setLevel(logging.ERROR)
-    all_text = []
-    try:
-        with pdfplumber.open(path) as pdf:
-            for page in pdf.pages:
-                text = page.extract_text()
-                if text:
-                    all_text.append(text.strip())
-    except Exception:
-        return ""
-    return "\n".join(all_text)
-
-def extract_text(file_path: str) -> str:
-    if file_path.endswith(".xlsx"):
-        return extract_text_from_excel(file_path)
-    elif file_path.endswith(".csv"):
-        return extract_text_from_csv(file_path)
-    elif file_path.endswith(".pdf"):
-        return extract_text_from_pdf(file_path)
-    else:
-        return ""
-
-def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
-    effective_limit = max_tokens - PROMPT_OVERHEAD
-    chunks, current, current_tokens = [], [], 0
-    for line in text.split("\n"):
-        tokens = estimate_tokens(line)
-        if current_tokens + tokens > effective_limit:
-            if current:
-                chunks.append("\n".join(current))
-            current, current_tokens = [line], tokens
-        else:
-            current.append(line)
-            current_tokens += tokens
-    if current:
-        chunks.append("\n".join(current))
-    return chunks
-
-def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
-    return [chunks[i:i+batch_size] for i in range(0, len(chunks), batch_size)]
-
-def build_prompt(chunk: str) -> str:
-    return f"""### Unstructured Clinical Records\n\nAnalyze the clinical notes below and summarize with:\n- Diagnostic Patterns\n- Medication Issues\n- Missed Opportunities\n- Inconsistencies\n- Follow-up Recommendations\n\n---\n\n{chunk}\n\n---\nRespond concisely in bullet points with clinical reasoning."""
+        agent = init_agent()
+    except Exception as e:
+        raise RuntimeError(f"Failed to initialize agent: {str(e)}")
 
 def init_agent() -> TxAgent:
+    """Initialize and return the TxAgent instance."""
     tool_path = os.path.join(tool_cache_dir, "new_tool.json")
     if not os.path.exists(tool_path):
         shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
+
     agent = TxAgent(
         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -175,229 +95,63 @@ def init_agent() -> TxAgent:
     agent.init_model()
     return agent
 
-def analyze_batches(agent, batches: List[List[str]]) -> List[str]:
-    results = []
-    for batch in batches:
-        prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
-        try:
-            batch_response = ""
-            for r in agent.run_gradio_chat(
-                message=prompt,
-                history=[],
-                temperature=0.0,
-                max_new_tokens=MAX_NEW_TOKENS,
-                max_token=MAX_MODEL_TOKENS,
-                call_agent=False,
-                conversation=[]
-            ):
-                if isinstance(r, str):
-                    batch_response += r
-                elif isinstance(r, list):
-                    for m in r:
-                        if hasattr(m, "content"):
-                            batch_response += m.content
-                elif hasattr(r, "content"):
-                    batch_response += r.content
-            results.append(clean_response(batch_response))
-            time.sleep(SAFE_SLEEP)
-        except Exception as e:
-            results.append(f"❌ Batch failed: {str(e)}")
-            time.sleep(SAFE_SLEEP * 2)
-        torch.cuda.empty_cache()
-        gc.collect()
-    return results
-
-def generate_final_summary(agent, combined: str) -> str:
-    combined = remove_duplicate_paragraphs(combined)
-    final_prompt = f"""
-You are an expert clinical summarizer. Analyze the following summaries carefully and generate a **single final concise structured medical report**, avoiding any repetition or redundancy.
-Summaries:
-{combined}
-Respond with:
-- Diagnostic Patterns
-- Medication Issues
-- Missed Opportunities
-- Inconsistencies
-- Follow-up Recommendations
-Avoid repeating the same points multiple times.
-""".strip()
-
-    final_response = ""
-    for r in agent.run_gradio_chat(
-        message=final_prompt,
-        history=[],
-        temperature=0.0,
-        max_new_tokens=MAX_NEW_TOKENS,
-        max_token=MAX_MODEL_TOKENS,
-        call_agent=False,
-        conversation=[]
-    ):
-        if isinstance(r, str):
-            final_response += r
-        elif isinstance(r, list):
-            for m in r:
-                if hasattr(m, "content"):
-                    final_response += m.content
-        elif hasattr(r, "content"):
-            final_response += r.content
-
-    final_response = clean_response(final_response)
-    final_response = remove_duplicate_paragraphs(final_response)
-    return final_response
-
-def remove_non_ascii(text):
-    return ''.join(c for c in text if ord(c) < 256)
-
-def generate_pdf_report_with_charts(summary: str, report_path: str, detailed_batches: List[str] = None):
-    chart_dir = os.path.join(os.path.dirname(report_path), "charts")
-    os.makedirs(chart_dir, exist_ok=True)
-
-    # Prepare static data
-    categories = ['Diagnostics', 'Medications', 'Missed', 'Inconsistencies', 'Follow-up']
-    values = [4, 2, 3, 1, 5]
-
-    # === Static Charts ===
-    chart_paths = []
-
-    def save_chart(fig_func, filename):
-        path = os.path.join(chart_dir, filename)
-        fig_func()
-        plt.tight_layout()
-        plt.savefig(path)
-        plt.close()
-        chart_paths.append((filename.split('.')[0].replace('_', ' ').title(), path))
-
-    save_chart(lambda: plt.bar(categories, values), "bar_chart.png")
-    save_chart(lambda: plt.pie(values, labels=categories, autopct='%1.1f%%'), "pie_chart.png")
-    save_chart(lambda: plt.plot(categories, values, marker='o'), "trend_chart.png")
-    save_chart(lambda: plt.barh(categories, values), "horizontal_bar_chart.png")
-
-    # Radar chart
-    import numpy as np
-    labels = np.array(categories)
-    stats = np.array(values)
-    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
-    stats = np.concatenate((stats, [stats[0]]))
-    angles += angles[:1]
-    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
-    ax.plot(angles, stats, marker='o')
-    ax.fill(angles, stats, alpha=0.25)
-    ax.set_yticklabels([])
-    ax.set_xticks(angles[:-1])
-    ax.set_xticklabels(labels)
-    ax.set_title('Radar Chart: Clinical Focus')
-    radar_path = os.path.join(chart_dir, "radar_chart.png")
-    plt.tight_layout()
-    plt.savefig(radar_path)
-    plt.close()
-    chart_paths.append(("Radar Chart: Clinical Focus", radar_path))
-
-    # === Dynamic Chart: Drug Frequency ===
-    drug_counter = {}
-    if detailed_batches:
-        for batch in detailed_batches:
-            lines = batch.split("\n")
-            for line in lines:
-                match = re.search(r"(?i)medication[s]?:\s*(.+)", line)
-                if match:
-                    items = re.split(r"[,;]", match.group(1))
-                    for item in items:
-                        drug = item.strip().title()
-                        if len(drug) > 2:
-                            drug_counter[drug] = drug_counter.get(drug, 0) + 1
-
-    if drug_counter:
-        drugs, freqs = zip(*sorted(drug_counter.items(), key=lambda x: x[1], reverse=True)[:10])
-        plt.figure(figsize=(6, 4))
-        plt.bar(drugs, freqs)
-        plt.xticks(rotation=45, ha='right')
-        plt.title('Top Medications Frequency')
-        drug_chart_path = os.path.join(chart_dir, "drug_frequency_chart.png")
-        plt.tight_layout()
-        plt.savefig(drug_chart_path)
-        plt.close()
-        chart_paths.append(("Top Medications Frequency", drug_chart_path))
-
-    # === PDF ===
-    pdf_path = report_path.replace('.md', '.pdf')
-    pdf = FPDF()
-    pdf.set_auto_page_break(auto=True, margin=20)
-
-    def add_section_title(pdf, title):
-        pdf.set_fill_color(230, 230, 230)
-        pdf.set_font("Arial", 'B', 14)
-        pdf.cell(0, 10, remove_non_ascii(title), ln=True, fill=True)
-        pdf.ln(3)
-
-    def add_footer(pdf):
-        pdf.set_y(-15)
-        pdf.set_font('Arial', 'I', 8)
-        pdf.set_text_color(150, 150, 150)
-        pdf.cell(0, 10, f"Page {pdf.page_no()}", align='C')
-
-    # Title Page
-    pdf.add_page()
-    pdf.set_font("Arial", 'B', 26)
-    pdf.set_text_color(0, 70, 140)
-    pdf.cell(0, 20, remove_non_ascii("Final Medical Report"), ln=True, align='C')
-    pdf.set_text_color(0, 0, 0)
-    pdf.set_font("Arial", '', 13)
-    pdf.cell(0, 10, datetime.now().strftime("Generated on %B %d, %Y at %H:%M"), ln=True, align='C')
-    pdf.ln(15)
-    pdf.set_font("Arial", '', 11)
-    pdf.set_fill_color(245, 245, 245)
-    pdf.multi_cell(0, 9, remove_non_ascii(
-        "This report contains a professional summary of clinical observations, potential inconsistencies, and follow-up recommendations based on the uploaded medical document."
-    ), border=1, fill=True, align="J")
-    add_footer(pdf)
-
-    # Final Summary
-    pdf.add_page()
-    add_section_title(pdf, "Final Summary")
-    pdf.set_font("Arial", '', 11)
-    for line in summary.split("\n"):
-        clean_line = remove_non_ascii(line.strip())
-        if clean_line:
-            pdf.multi_cell(0, 8, txt=clean_line)
-    add_footer(pdf)
+# Utility functions (keep your existing functions but add error handling)
+def estimate_tokens(text: str) -> int:
+    """Estimate the number of tokens in the given text."""
+    return len(text) // 4 + 1
 
-    # Charts Section
-    pdf.add_page()
-    add_section_title(pdf, "Statistical Overview")
-    for title, path in chart_paths:
-        pdf.set_font("Arial", 'B', 12)
-        pdf.cell(0, 9, remove_non_ascii(title), ln=True)
-        pdf.image(path, w=170)
-        pdf.ln(6)
-    add_footer(pdf)
+def clean_response(text: str) -> str:
+    """Clean and format the response text."""
+    if not text:
+        return ""
+    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
+    text = re.sub(r"\n{3,}", "\n\n", text)
+    return text.strip()
 
-    # Detailed Tool Outputs
-    if detailed_batches:
-        pdf.add_page()
-        add_section_title(pdf, "Detailed Tool Insights")
-        for idx, detail in enumerate(detailed_batches):
-            pdf.set_font("Arial", 'B', 12)
-            pdf.cell(0, 9, remove_non_ascii(f"Tool Output #{idx + 1}"), ln=True)
-            pdf.set_font("Arial", '', 11)
-            for line in remove_non_ascii(detail).split("\n"):
-                pdf.multi_cell(0, 8, txt=line.strip())
-            pdf.ln(3)
-    add_footer(pdf)
+def extract_text_from_excel(path: str) -> str:
+    """Extract text from Excel file."""
+    try:
+        all_text = []
+        xls = pd.ExcelFile(path)
+        for sheet_name in xls.sheet_names:
+            try:
+                df = xls.parse(sheet_name).astype(str).fillna("")
+            except Exception:
+                continue
+            for _, row in df.iterrows():
+                non_empty = [cell.strip() for cell in row if cell.strip()]
+                if len(non_empty) >= 2:
+                    text_line = " | ".join(non_empty)
+                    if len(text_line) > 15:
+                        all_text.append(f"[{sheet_name}] {text_line}")
+        return "\n".join(all_text)
+    except Exception as e:
+        raise RuntimeError(f"Failed to extract text from Excel: {str(e)}")
 
-    pdf.output(pdf_path)
-    return pdf_path
+def extract_text(file_path: str) -> str:
+    """Extract text from supported file types."""
+    try:
+        if file_path.endswith(".xlsx"):
+            return extract_text_from_excel(file_path)
+        elif file_path.endswith(".csv"):
+            df = pd.read_csv(file_path).astype(str).fillna("")
+            return "\n".join(
+                " | ".join(cell.strip() for cell in row if cell.strip())
+                for _, row in df.iterrows()
+                if len([cell for cell in row if cell.strip()]) >= 2
+            )
+        elif file_path.endswith(".pdf"):
+            with pdfplumber.open(file_path) as pdf:
+                return "\n".join(page.extract_text() or "" for page in pdf.pages)
+        else:
+            return ""
+    except Exception as e:
+        raise RuntimeError(f"Failed to extract text from file: {str(e)}")
 
-@app.post("/analyze", summary="Analyze medical document", response_description="Returns analysis results")
+# API endpoints
+@app.post("/analyze")
 async def analyze_document(file: UploadFile = File(...)):
-    """
-    Analyze a medical document (PDF, Excel, or CSV) and return a structured analysis.
-
-    Args:
-        file: The medical document to analyze (PDF, Excel, or CSV format)
-
-    Returns:
-        JSONResponse: Contains analysis results and report download path
-    """
+    """Analyze a medical document and return results."""
     start_time = time.time()
 
     try:
@@ -413,50 +167,40 @@ async def analyze_document(file: UploadFile = File(...)):
         chunks = split_text(extracted)
         batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
         batch_results = analyze_batches(agent, batches)
-        all_tool_outputs = batch_results.copy()
-        valid = [res for res in batch_results if not res.startswith("❌")]
-
-        if not valid:
+
+        valid_results = [res for res in batch_results if not res.startswith("❌")]
+        if not valid_results:
             raise HTTPException(status_code=400, detail="No valid analysis results were generated")
 
-        summary = generate_final_summary(agent, "\n\n".join(valid))
+        final_summary = generate_final_summary(agent, "\n\n".join(valid_results))
 
         # Generate report files
         report_filename = f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        report_path = os.path.join(report_dir, f"{report_filename}.md")
         with open(report_path, 'w', encoding='utf-8') as f:
-            f.write(f"# Final Medical Report\n\n{summary}")
+            f.write(f"# Final Medical Report\n\n{final_summary}")
 
-        pdf_path = generate_pdf_report_with_charts(summary, report_path, detailed_batches=all_tool_outputs)
-
-        end_time = time.time()
-        elapsed_time = end_time - start_time
+        pdf_path = generate_pdf_report_with_charts(final_summary, report_path, detailed_batches=batch_results)
 
         # Clean up temp file
         os.remove(temp_path)
 
         return JSONResponse({
             "status": "success",
-            "summary": summary,
+            "summary": final_summary,
             "report_path": f"/reports/{os.path.basename(pdf_path)}",
-            "processing_time": f"{elapsed_time:.2f} seconds",
-            "detailed_outputs": all_tool_outputs
+            "processing_time": f"{time.time() - start_time:.2f} seconds",
+            "detailed_outputs": batch_results
         })
 
+    except HTTPException:
+        raise
     except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
 
-@app.get("/reports/{filename}", response_class=FileResponse)
+@app.get("/reports/{filename}")
 async def download_report(filename: str):
-    """
-    Download a generated report PDF file.
-
-    Args:
-        filename: The name of the report file to download
-
-    Returns:
-        FileResponse: The PDF file for download
-    """
+    """Download a generated report."""
     file_path = os.path.join(report_dir, filename)
     if not os.path.exists(file_path):
         raise HTTPException(status_code=404, detail="Report not found")
@@ -464,20 +208,14 @@ async def download_report(filename: str):
 
 @app.get("/status")
 async def service_status():
-    """
-    Check the service status and version information.
-
-    Returns:
-        JSONResponse: Service status information
-    """
-    return JSONResponse({
+    """Check service status."""
+    return {
         "status": "running",
         "version": "1.0.0",
         "model": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
-        "rag_model": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        "max_tokens": MAX_MODEL_TOKENS,
         "supported_file_types": [".pdf", ".xlsx", ".csv"]
-    })
+    }
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
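
A note on the retained chunking constants: with the len(text) // 4 + 1 estimator, split_text (kept in app.py per the "keep your existing functions" comment above) budgets MAX_CHUNK_TOKENS minus PROMPT_OVERHEAD tokens per chunk, i.e. roughly 31.5k characters of notes. A minimal sketch of that arithmetic, with names copied from app.py and the printout purely illustrative:

# Chunk budget implied by the constants in app.py (illustrative sketch)
MAX_CHUNK_TOKENS = 8192
PROMPT_OVERHEAD = 300

def estimate_tokens(text: str) -> int:
    # ~4 characters per token, same heuristic as app.py
    return len(text) // 4 + 1

effective_limit = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD  # 7892 tokens
print(f"~{effective_limit * 4} characters of notes per chunk")  # ~31568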
 
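To exercise the updated endpoints, here is a minimal client sketch. It assumes the app is running locally on port 7860 (matching the uvicorn.run call); sample.pdf is a placeholder file name, and the requests package is an external dependency, not part of this commit:

import requests

BASE = "http://localhost:7860"

# Health check against the trimmed /status payload
print(requests.get(f"{BASE}/status").json())

# POST a document to /analyze; the multipart field must be named "file"
# to match the UploadFile = File(...) parameter
with open("sample.pdf", "rb") as fh:
    resp = requests.post(f"{BASE}/analyze",
                         files={"file": ("sample.pdf", fh, "application/pdf")})
resp.raise_for_status()
result = resp.json()
print(result["processing_time"])

# result["report_path"] points at the /reports/{filename} route
report = requests.get(BASE + result["report_path"])
with open("downloaded_report.pdf", "wb") as out:
    out.write(report.content)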