Ali2206 commited on
Commit
f126604
·
verified ·
1 Parent(s): 1af2b59

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +483 -0
app.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import shutil
4
+ import re
5
+ import gc
6
+ import time
7
+ from datetime import datetime
8
+ from typing import List, Tuple, Dict, Union, Optional
9
+ from fastapi import FastAPI, UploadFile, File, HTTPException
10
+ from fastapi.responses import FileResponse, JSONResponse
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ import pandas as pd
13
+ import pdfplumber
14
+ import torch
15
+ import matplotlib.pyplot as plt
16
+ from fpdf import FPDF
17
+ import unicodedata
18
+ import uvicorn
19
+
20
+ # === Configuration ===
21
+ persistent_dir = "/data/hf_cache"
22
+ model_cache_dir = os.path.join(persistent_dir, "txagent_models")
23
+ tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
24
+ file_cache_dir = os.path.join(persistent_dir, "cache")
25
+ report_dir = os.path.join(persistent_dir, "reports")
26
+
27
+ for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
28
+ os.makedirs(d, exist_ok=True)
29
+
30
+ os.environ["HF_HOME"] = model_cache_dir
31
+ os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
32
+
33
+ current_dir = os.path.dirname(os.path.abspath(__file__))
34
+ src_path = os.path.abspath(os.path.join(current_dir, "src"))
35
+ sys.path.insert(0, src_path)
36
+
37
+ from txagent.txagent import TxAgent
38
+
39
+ MAX_MODEL_TOKENS = 131072
40
+ MAX_NEW_TOKENS = 4096
41
+ MAX_CHUNK_TOKENS = 8192
42
+ BATCH_SIZE = 1
43
+ PROMPT_OVERHEAD = 300
44
+ SAFE_SLEEP = 0.5
45
+
46
+ app = FastAPI(title="Clinical Patient Support System API",
47
+ description="API for analyzing and summarizing unstructured medical files",
48
+ version="1.0.0")
49
+
50
+ # CORS configuration
51
+ app.add_middleware(
52
+ CORSMiddleware,
53
+ allow_origins=["*"],
54
+ allow_credentials=True,
55
+ allow_methods=["*"],
56
+ allow_headers=["*"],
57
+ )
58
+
59
+ # Initialize agent at startup
60
+ agent = None
61
+
62
+ @app.on_event("startup")
63
+ async def startup_event():
64
+ global agent
65
+ agent = init_agent()
66
+
67
+ def estimate_tokens(text: str) -> int:
68
+ return len(text) // 4 + 1
69
+
70
+ def clean_response(text: str) -> str:
71
+ text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
72
+ text = re.sub(r"\n{3,}", "\n\n", text)
73
+ return text.strip()
74
+
75
+ def remove_duplicate_paragraphs(text: str) -> str:
76
+ paragraphs = text.strip().split("\n\n")
77
+ seen = set()
78
+ unique_paragraphs = []
79
+ for p in paragraphs:
80
+ clean_p = p.strip()
81
+ if clean_p and clean_p not in seen:
82
+ unique_paragraphs.append(clean_p)
83
+ seen.add(clean_p)
84
+ return "\n\n".join(unique_paragraphs)
85
+
86
+ def extract_text_from_excel(path: str) -> str:
87
+ all_text = []
88
+ xls = pd.ExcelFile(path)
89
+ for sheet_name in xls.sheet_names:
90
+ try:
91
+ df = xls.parse(sheet_name).astype(str).fillna("")
92
+ except Exception:
93
+ continue
94
+ for _, row in df.iterrows():
95
+ non_empty = [cell.strip() for cell in row if cell.strip()]
96
+ if len(non_empty) >= 2:
97
+ text_line = " | ".join(non_empty)
98
+ if len(text_line) > 15:
99
+ all_text.append(f"[{sheet_name}] {text_line}")
100
+ return "\n".join(all_text)
101
+
102
+ def extract_text_from_csv(path: str) -> str:
103
+ all_text = []
104
+ try:
105
+ df = pd.read_csv(path).astype(str).fillna("")
106
+ except Exception:
107
+ return ""
108
+ for _, row in df.iterrows():
109
+ non_empty = [cell.strip() for cell in row if cell.strip()]
110
+ if len(non_empty) >= 2:
111
+ text_line = " | ".join(non_empty)
112
+ if len(text_line) > 15:
113
+ all_text.append(text_line)
114
+ return "\n".join(all_text)
115
+
116
+ def extract_text_from_pdf(path: str) -> str:
117
+ import logging
118
+ logging.getLogger("pdfminer").setLevel(logging.ERROR)
119
+ all_text = []
120
+ try:
121
+ with pdfplumber.open(path) as pdf:
122
+ for page in pdf.pages:
123
+ text = page.extract_text()
124
+ if text:
125
+ all_text.append(text.strip())
126
+ except Exception:
127
+ return ""
128
+ return "\n".join(all_text)
129
+
130
+ def extract_text(file_path: str) -> str:
131
+ if file_path.endswith(".xlsx"):
132
+ return extract_text_from_excel(file_path)
133
+ elif file_path.endswith(".csv"):
134
+ return extract_text_from_csv(file_path)
135
+ elif file_path.endswith(".pdf"):
136
+ return extract_text_from_pdf(file_path)
137
+ else:
138
+ return ""
139
+
140
+ def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
141
+ effective_limit = max_tokens - PROMPT_OVERHEAD
142
+ chunks, current, current_tokens = [], [], 0
143
+ for line in text.split("\n"):
144
+ tokens = estimate_tokens(line)
145
+ if current_tokens + tokens > effective_limit:
146
+ if current:
147
+ chunks.append("\n".join(current))
148
+ current, current_tokens = [line], tokens
149
+ else:
150
+ current.append(line)
151
+ current_tokens += tokens
152
+ if current:
153
+ chunks.append("\n".join(current))
154
+ return chunks
155
+
156
+ def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
157
+ return [chunks[i:i+batch_size] for i in range(0, len(chunks), batch_size)]
158
+
159
+ def build_prompt(chunk: str) -> str:
160
+ return f"""### Unstructured Clinical Records\n\nAnalyze the clinical notes below and summarize with:\n- Diagnostic Patterns\n- Medication Issues\n- Missed Opportunities\n- Inconsistencies\n- Follow-up Recommendations\n\n---\n\n{chunk}\n\n---\nRespond concisely in bullet points with clinical reasoning."""
161
+
162
+ def init_agent() -> TxAgent:
163
+ tool_path = os.path.join(tool_cache_dir, "new_tool.json")
164
+ if not os.path.exists(tool_path):
165
+ shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
166
+ agent = TxAgent(
167
+ model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
168
+ rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
169
+ tool_files_dict={"new_tool": tool_path},
170
+ force_finish=True,
171
+ enable_checker=True,
172
+ step_rag_num=4,
173
+ seed=100
174
+ )
175
+ agent.init_model()
176
+ return agent
177
+
178
+ def analyze_batches(agent, batches: List[List[str]]) -> List[str]:
179
+ results = []
180
+ for batch in batches:
181
+ prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
182
+ try:
183
+ batch_response = ""
184
+ for r in agent.run_gradio_chat(
185
+ message=prompt,
186
+ history=[],
187
+ temperature=0.0,
188
+ max_new_tokens=MAX_NEW_TOKENS,
189
+ max_token=MAX_MODEL_TOKENS,
190
+ call_agent=False,
191
+ conversation=[]
192
+ ):
193
+ if isinstance(r, str):
194
+ batch_response += r
195
+ elif isinstance(r, list):
196
+ for m in r:
197
+ if hasattr(m, "content"):
198
+ batch_response += m.content
199
+ elif hasattr(r, "content"):
200
+ batch_response += r.content
201
+ results.append(clean_response(batch_response))
202
+ time.sleep(SAFE_SLEEP)
203
+ except Exception as e:
204
+ results.append(f"❌ Batch failed: {str(e)}")
205
+ time.sleep(SAFE_SLEEP * 2)
206
+ torch.cuda.empty_cache()
207
+ gc.collect()
208
+ return results
209
+
210
+ def generate_final_summary(agent, combined: str) -> str:
211
+ combined = remove_duplicate_paragraphs(combined)
212
+ final_prompt = f"""
213
+ You are an expert clinical summarizer. Analyze the following summaries carefully and generate a **single final concise structured medical report**, avoiding any repetition or redundancy.
214
+ Summaries:
215
+ {combined}
216
+ Respond with:
217
+ - Diagnostic Patterns
218
+ - Medication Issues
219
+ - Missed Opportunities
220
+ - Inconsistencies
221
+ - Follow-up Recommendations
222
+ Avoid repeating the same points multiple times.
223
+ """.strip()
224
+
225
+ final_response = ""
226
+ for r in agent.run_gradio_chat(
227
+ message=final_prompt,
228
+ history=[],
229
+ temperature=0.0,
230
+ max_new_tokens=MAX_NEW_TOKENS,
231
+ max_token=MAX_MODEL_TOKENS,
232
+ call_agent=False,
233
+ conversation=[]
234
+ ):
235
+ if isinstance(r, str):
236
+ final_response += r
237
+ elif isinstance(r, list):
238
+ for m in r:
239
+ if hasattr(m, "content"):
240
+ final_response += m.content
241
+ elif hasattr(r, "content"):
242
+ final_response += r.content
243
+
244
+ final_response = clean_response(final_response)
245
+ final_response = remove_duplicate_paragraphs(final_response)
246
+ return final_response
247
+
248
+ def remove_non_ascii(text):
249
+ return ''.join(c for c in text if ord(c) < 256)
250
+
251
+ def generate_pdf_report_with_charts(summary: str, report_path: str, detailed_batches: List[str] = None):
252
+ chart_dir = os.path.join(os.path.dirname(report_path), "charts")
253
+ os.makedirs(chart_dir, exist_ok=True)
254
+
255
+ # Prepare static data
256
+ categories = ['Diagnostics', 'Medications', 'Missed', 'Inconsistencies', 'Follow-up']
257
+ values = [4, 2, 3, 1, 5]
258
+
259
+ # === Static Charts ===
260
+ chart_paths = []
261
+
262
+ def save_chart(fig_func, filename):
263
+ path = os.path.join(chart_dir, filename)
264
+ fig_func()
265
+ plt.tight_layout()
266
+ plt.savefig(path)
267
+ plt.close()
268
+ chart_paths.append((filename.split('.')[0].replace('_', ' ').title(), path))
269
+
270
+ save_chart(lambda: plt.bar(categories, values), "bar_chart.png")
271
+ save_chart(lambda: plt.pie(values, labels=categories, autopct='%1.1f%%'), "pie_chart.png")
272
+ save_chart(lambda: plt.plot(categories, values, marker='o'), "trend_chart.png")
273
+ save_chart(lambda: plt.barh(categories, values), "horizontal_bar_chart.png")
274
+
275
+ # Radar chart
276
+ import numpy as np
277
+ labels = np.array(categories)
278
+ stats = np.array(values)
279
+ angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
280
+ stats = np.concatenate((stats, [stats[0]]))
281
+ angles += angles[:1]
282
+ fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
283
+ ax.plot(angles, stats, marker='o')
284
+ ax.fill(angles, stats, alpha=0.25)
285
+ ax.set_yticklabels([])
286
+ ax.set_xticks(angles[:-1])
287
+ ax.set_xticklabels(labels)
288
+ ax.set_title('Radar Chart: Clinical Focus')
289
+ radar_path = os.path.join(chart_dir, "radar_chart.png")
290
+ plt.tight_layout()
291
+ plt.savefig(radar_path)
292
+ plt.close()
293
+ chart_paths.append(("Radar Chart: Clinical Focus", radar_path))
294
+
295
+ # === Dynamic Chart: Drug Frequency ===
296
+ drug_counter = {}
297
+ if detailed_batches:
298
+ for batch in detailed_batches:
299
+ lines = batch.split("\n")
300
+ for line in lines:
301
+ match = re.search(r"(?i)medication[s]?:\s*(.+)", line)
302
+ if match:
303
+ items = re.split(r"[,;]", match.group(1))
304
+ for item in items:
305
+ drug = item.strip().title()
306
+ if len(drug) > 2:
307
+ drug_counter[drug] = drug_counter.get(drug, 0) + 1
308
+
309
+ if drug_counter:
310
+ drugs, freqs = zip(*sorted(drug_counter.items(), key=lambda x: x[1], reverse=True)[:10])
311
+ plt.figure(figsize=(6, 4))
312
+ plt.bar(drugs, freqs)
313
+ plt.xticks(rotation=45, ha='right')
314
+ plt.title('Top Medications Frequency')
315
+ drug_chart_path = os.path.join(chart_dir, "drug_frequency_chart.png")
316
+ plt.tight_layout()
317
+ plt.savefig(drug_chart_path)
318
+ plt.close()
319
+ chart_paths.append(("Top Medications Frequency", drug_chart_path))
320
+
321
+ # === PDF ===
322
+ pdf_path = report_path.replace('.md', '.pdf')
323
+ pdf = FPDF()
324
+ pdf.set_auto_page_break(auto=True, margin=20)
325
+
326
+ def add_section_title(pdf, title):
327
+ pdf.set_fill_color(230, 230, 230)
328
+ pdf.set_font("Arial", 'B', 14)
329
+ pdf.cell(0, 10, remove_non_ascii(title), ln=True, fill=True)
330
+ pdf.ln(3)
331
+
332
+ def add_footer(pdf):
333
+ pdf.set_y(-15)
334
+ pdf.set_font('Arial', 'I', 8)
335
+ pdf.set_text_color(150, 150, 150)
336
+ pdf.cell(0, 10, f"Page {pdf.page_no()}", align='C')
337
+
338
+ # Title Page
339
+ pdf.add_page()
340
+ pdf.set_font("Arial", 'B', 26)
341
+ pdf.set_text_color(0, 70, 140)
342
+ pdf.cell(0, 20, remove_non_ascii("Final Medical Report"), ln=True, align='C')
343
+ pdf.set_text_color(0, 0, 0)
344
+ pdf.set_font("Arial", '', 13)
345
+ pdf.cell(0, 10, datetime.now().strftime("Generated on %B %d, %Y at %H:%M"), ln=True, align='C')
346
+ pdf.ln(15)
347
+ pdf.set_font("Arial", '', 11)
348
+ pdf.set_fill_color(245, 245, 245)
349
+ pdf.multi_cell(0, 9, remove_non_ascii(
350
+ "This report contains a professional summary of clinical observations, potential inconsistencies, and follow-up recommendations based on the uploaded medical document."
351
+ ), border=1, fill=True, align="J")
352
+ add_footer(pdf)
353
+
354
+ # Final Summary
355
+ pdf.add_page()
356
+ add_section_title(pdf, "Final Summary")
357
+ pdf.set_font("Arial", '', 11)
358
+ for line in summary.split("\n"):
359
+ clean_line = remove_non_ascii(line.strip())
360
+ if clean_line:
361
+ pdf.multi_cell(0, 8, txt=clean_line)
362
+ add_footer(pdf)
363
+
364
+ # Charts Section
365
+ pdf.add_page()
366
+ add_section_title(pdf, "Statistical Overview")
367
+ for title, path in chart_paths:
368
+ pdf.set_font("Arial", 'B', 12)
369
+ pdf.cell(0, 9, remove_non_ascii(title), ln=True)
370
+ pdf.image(path, w=170)
371
+ pdf.ln(6)
372
+ add_footer(pdf)
373
+
374
+ # Detailed Tool Outputs
375
+ if detailed_batches:
376
+ pdf.add_page()
377
+ add_section_title(pdf, "Detailed Tool Insights")
378
+ for idx, detail in enumerate(detailed_batches):
379
+ pdf.set_font("Arial", 'B', 12)
380
+ pdf.cell(0, 9, remove_non_ascii(f"Tool Output #{idx + 1}"), ln=True)
381
+ pdf.set_font("Arial", '', 11)
382
+ for line in remove_non_ascii(detail).split("\n"):
383
+ pdf.multi_cell(0, 8, txt=line.strip())
384
+ pdf.ln(3)
385
+ add_footer(pdf)
386
+
387
+ pdf.output(pdf_path)
388
+ return pdf_path
389
+
390
+ @app.post("/analyze", summary="Analyze medical document", response_description="Returns analysis results")
391
+ async def analyze_document(file: UploadFile = File(...)):
392
+ """
393
+ Analyze a medical document (PDF, Excel, or CSV) and return a structured analysis.
394
+
395
+ Args:
396
+ file: The medical document to analyze (PDF, Excel, or CSV format)
397
+
398
+ Returns:
399
+ JSONResponse: Contains analysis results and report download path
400
+ """
401
+ start_time = time.time()
402
+
403
+ try:
404
+ # Save the uploaded file temporarily
405
+ temp_path = os.path.join(file_cache_dir, file.filename)
406
+ with open(temp_path, "wb") as f:
407
+ f.write(await file.read())
408
+
409
+ extracted = extract_text(temp_path)
410
+ if not extracted:
411
+ raise HTTPException(status_code=400, detail="Could not extract text from the file")
412
+
413
+ chunks = split_text(extracted)
414
+ batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
415
+ batch_results = analyze_batches(agent, batches)
416
+ all_tool_outputs = batch_results.copy()
417
+ valid = [res for res in batch_results if not res.startswith("❌")]
418
+
419
+ if not valid:
420
+ raise HTTPException(status_code=400, detail="No valid analysis results were generated")
421
+
422
+ summary = generate_final_summary(agent, "\n\n".join(valid))
423
+
424
+ # Generate report files
425
+ report_filename = f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
426
+ report_path = os.path.join(report_dir, f"{report_filename}.md")
427
+ with open(report_path, 'w', encoding='utf-8') as f:
428
+ f.write(f"# Final Medical Report\n\n{summary}")
429
+
430
+ pdf_path = generate_pdf_report_with_charts(summary, report_path, detailed_batches=all_tool_outputs)
431
+
432
+ end_time = time.time()
433
+ elapsed_time = end_time - start_time
434
+
435
+ # Clean up temp file
436
+ os.remove(temp_path)
437
+
438
+ return JSONResponse({
439
+ "status": "success",
440
+ "summary": summary,
441
+ "report_path": f"/reports/{os.path.basename(pdf_path)}",
442
+ "processing_time": f"{elapsed_time:.2f} seconds",
443
+ "detailed_outputs": all_tool_outputs
444
+ })
445
+
446
+ except Exception as e:
447
+ raise HTTPException(status_code=500, detail=str(e))
448
+
449
+ @app.get("/reports/{filename}", response_class=FileResponse)
450
+ async def download_report(filename: str):
451
+ """
452
+ Download a generated report PDF file.
453
+
454
+ Args:
455
+ filename: The name of the report file to download
456
+
457
+ Returns:
458
+ FileResponse: The PDF file for download
459
+ """
460
+ file_path = os.path.join(report_dir, filename)
461
+ if not os.path.exists(file_path):
462
+ raise HTTPException(status_code=404, detail="Report not found")
463
+ return FileResponse(file_path, media_type='application/pdf', filename=filename)
464
+
465
+ @app.get("/status")
466
+ async def service_status():
467
+ """
468
+ Check the service status and version information.
469
+
470
+ Returns:
471
+ JSONResponse: Service status information
472
+ """
473
+ return JSONResponse({
474
+ "status": "running",
475
+ "version": "1.0.0",
476
+ "model": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
477
+ "rag_model": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
478
+ "max_tokens": MAX_MODEL_TOKENS,
479
+ "supported_file_types": [".pdf", ".xlsx", ".csv"]
480
+ })
481
+
482
+ if __name__ == "__main__":
483
+ uvicorn.run(app, host="0.0.0.0", port=7860)