Ali2206 committed on
Commit
d377221
·
verified ·
1 Parent(s): f9e3082

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -126
app.py CHANGED
@@ -1,24 +1,15 @@
1
- import sys
2
  import os
 
3
  import json
4
  import shutil
5
- import re
6
- import gc
7
- import time
8
- from datetime import datetime
9
- from typing import List, Tuple, Dict, Union, Optional
10
  from fastapi import FastAPI, UploadFile, File, HTTPException
11
- from fastapi.responses import FileResponse, JSONResponse
12
  from fastapi.middleware.cors import CORSMiddleware
13
- import pandas as pd
14
- import pdfplumber
15
  import torch
16
- import matplotlib.pyplot as plt
17
- from fpdf import FPDF
18
- import unicodedata
19
- import uvicorn
20
 
21
- # === Configuration ===
22
  persistent_dir = "/data/hf_cache"
23
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
24
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
@@ -26,34 +17,24 @@ file_cache_dir = os.path.join(persistent_dir, "cache")
26
  report_dir = os.path.join(persistent_dir, "reports")
27
 
28
  # Create directories if they don't exist
29
- for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
30
- os.makedirs(d, exist_ok=True)
 
 
31
 
32
  # Set environment variables
33
  os.environ["HF_HOME"] = model_cache_dir
34
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
35
- os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib" # Fix for matplotlib permission issues
36
 
37
  # Set up Python path
38
  current_dir = os.path.dirname(os.path.abspath(__file__))
39
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
40
  sys.path.insert(0, src_path)
41
 
42
- # Import TxAgent after setting up paths
43
- from txagent.txagent import TxAgent
44
-
45
- # Constants
46
- MAX_MODEL_TOKENS = 131072
47
- MAX_NEW_TOKENS = 4096
48
- MAX_CHUNK_TOKENS = 8192
49
- BATCH_SIZE = 1
50
- PROMPT_OVERHEAD = 300
51
- SAFE_SLEEP = 0.5
52
-
53
  # Initialize FastAPI app
54
  app = FastAPI(
55
  title="Clinical Patient Support System API",
56
- description="API for analyzing and summarizing unstructured medical files",
57
  version="1.0.0"
58
  )
59
 
@@ -77,12 +58,13 @@ async def startup_event():
77
  except Exception as e:
78
  raise RuntimeError(f"Failed to initialize agent: {str(e)}")
79
 
80
- def init_agent() -> TxAgent:
81
  """Initialize and return the TxAgent instance."""
82
  tool_path = os.path.join(tool_cache_dir, "new_tool.json")
83
  if not os.path.exists(tool_path):
84
  shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
85
 
 
86
  agent = TxAgent(
87
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
88
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -90,132 +72,46 @@ def init_agent() -> TxAgent:
90
  force_finish=True,
91
  enable_checker=True,
92
  step_rag_num=4,
93
- seed=100
 
94
  )
95
  agent.init_model()
96
  return agent
97
 
98
- # Utility functions (keep your existing functions but add error handling)
99
- def estimate_tokens(text: str) -> int:
100
- """Estimate the number of tokens in the given text."""
101
- return len(text) // 4 + 1
102
-
103
- def clean_response(text: str) -> str:
104
- """Clean and format the response text."""
105
- if not text:
106
- return ""
107
- text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
108
- text = re.sub(r"\n{3,}", "\n\n", text)
109
- return text.strip()
110
-
111
- def extract_text_from_excel(path: str) -> str:
112
- """Extract text from Excel file."""
113
- try:
114
- all_text = []
115
- xls = pd.ExcelFile(path)
116
- for sheet_name in xls.sheet_names:
117
- try:
118
- df = xls.parse(sheet_name).astype(str).fillna("")
119
- except Exception:
120
- continue
121
- for _, row in df.iterrows():
122
- non_empty = [cell.strip() for cell in row if cell.strip()]
123
- if len(non_empty) >= 2:
124
- text_line = " | ".join(non_empty)
125
- if len(text_line) > 15:
126
- all_text.append(f"[{sheet_name}] {text_line}")
127
- return "\n".join(all_text)
128
- except Exception as e:
129
- raise RuntimeError(f"Failed to extract text from Excel: {str(e)}")
130
-
131
- def extract_text(file_path: str) -> str:
132
- """Extract text from supported file types."""
133
- try:
134
- if file_path.endswith(".xlsx"):
135
- return extract_text_from_excel(file_path)
136
- elif file_path.endswith(".csv"):
137
- df = pd.read_csv(file_path).astype(str).fillna("")
138
- return "\n".join(
139
- " | ".join(cell.strip() for cell in row if cell.strip())
140
- for _, row in df.iterrows()
141
- if len([cell for cell in row if cell.strip()]) >= 2
142
- )
143
- elif file_path.endswith(".pdf"):
144
- with pdfplumber.open(file_path) as pdf:
145
- return "\n".join(page.extract_text() or "" for page in pdf.pages)
146
- else:
147
- return ""
148
- except Exception as e:
149
- raise RuntimeError(f"Failed to extract text from file: {str(e)}")
150
-
151
- # API endpoints
152
  @app.post("/analyze")
153
  async def analyze_document(file: UploadFile = File(...)):
154
  """Analyze a medical document and return results."""
155
- start_time = time.time()
156
-
157
  try:
158
  # Save the uploaded file temporarily
159
  temp_path = os.path.join(file_cache_dir, file.filename)
160
  with open(temp_path, "wb") as f:
161
  f.write(await file.read())
162
 
163
- extracted = extract_text(temp_path)
164
- if not extracted:
165
- raise HTTPException(status_code=400, detail="Could not extract text from the file")
166
-
167
- chunks = split_text(extracted)
168
- batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
169
- batch_results = analyze_batches(agent, batches)
170
 
171
- valid_results = [res for res in batch_results if not res.startswith("❌")]
172
- if not valid_results:
173
- raise HTTPException(status_code=400, detail="No valid analysis results were generated")
174
-
175
- final_summary = generate_final_summary(agent, "\n\n".join(valid_results))
176
-
177
- # Generate report files
178
- report_filename = f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
179
- report_path = os.path.join(report_dir, f"{report_filename}.md")
180
- with open(report_path, 'w', encoding='utf-8') as f:
181
- f.write(f"# Final Medical Report\n\n{final_summary}")
182
-
183
- pdf_path = generate_pdf_report_with_charts(final_summary, report_path, detailed_batches=batch_results)
184
-
185
- # Clean up temp file
186
  os.remove(temp_path)
187
 
188
  return JSONResponse({
189
  "status": "success",
190
- "summary": final_summary,
191
- "report_path": f"/reports/{os.path.basename(pdf_path)}",
192
- "processing_time": f"{time.time() - start_time:.2f} seconds",
193
- "detailed_outputs": batch_results
194
  })
195
 
196
- except HTTPException:
197
- raise
198
  except Exception as e:
199
  raise HTTPException(status_code=500, detail=str(e))
200
 
201
- @app.get("/reports/{filename}")
202
- async def download_report(filename: str):
203
- """Download a generated report."""
204
- file_path = os.path.join(report_dir, filename)
205
- if not os.path.exists(file_path):
206
- raise HTTPException(status_code=404, detail="Report not found")
207
- return FileResponse(file_path, media_type='application/pdf', filename=filename)
208
-
209
  @app.get("/status")
210
  async def service_status():
211
  """Check service status."""
212
  return {
213
  "status": "running",
214
  "version": "1.0.0",
215
- "model": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
216
- "max_tokens": MAX_MODEL_TOKENS,
217
- "supported_file_types": [".pdf", ".xlsx", ".csv"]
218
  }
219
 
220
  if __name__ == "__main__":
 
221
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
1
  import os
2
+ import sys
3
  import json
4
  import shutil
 
 
 
 
 
5
  from fastapi import FastAPI, UploadFile, File, HTTPException
6
+ from fastapi.responses import JSONResponse, FileResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from typing import List, Dict, Optional
 
9
  import torch
10
+ from datetime import datetime
 
 
 
11
 
12
+ # Configuration
13
  persistent_dir = "/data/hf_cache"
14
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
15
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
 
17
  report_dir = os.path.join(persistent_dir, "reports")
18
 
19
  # Create directories if they don't exist
20
+ os.makedirs(model_cache_dir, exist_ok=True)
21
+ os.makedirs(tool_cache_dir, exist_ok=True)
22
+ os.makedirs(file_cache_dir, exist_ok=True)
23
+ os.makedirs(report_dir, exist_ok=True)
24
 
25
  # Set environment variables
26
  os.environ["HF_HOME"] = model_cache_dir
27
  os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 
28
 
29
  # Set up Python path
30
  current_dir = os.path.dirname(os.path.abspath(__file__))
31
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
32
  sys.path.insert(0, src_path)
33
 
 
 
 
 
 
 
 
 
 
 
 
34
  # Initialize FastAPI app
35
  app = FastAPI(
36
  title="Clinical Patient Support System API",
37
+ description="API for analyzing medical documents",
38
  version="1.0.0"
39
  )
40
 
 
58
  except Exception as e:
59
  raise RuntimeError(f"Failed to initialize agent: {str(e)}")
60
 
61
+ def init_agent():
62
  """Initialize and return the TxAgent instance."""
63
  tool_path = os.path.join(tool_cache_dir, "new_tool.json")
64
  if not os.path.exists(tool_path):
65
  shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
66
 
67
+ from txagent.txagent import TxAgent
68
  agent = TxAgent(
69
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
70
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
72
  force_finish=True,
73
  enable_checker=True,
74
  step_rag_num=4,
75
+ seed=100,
76
+ use_vllm=False # Disable vLLM for Hugging Face Spaces
77
  )
78
  agent.init_model()
79
  return agent
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  @app.post("/analyze")
82
  async def analyze_document(file: UploadFile = File(...)):
83
  """Analyze a medical document and return results."""
 
 
84
  try:
85
  # Save the uploaded file temporarily
86
  temp_path = os.path.join(file_cache_dir, file.filename)
87
  with open(temp_path, "wb") as f:
88
  f.write(await file.read())
89
 
90
+ # Process the file and generate response
91
+ result = agent.process_document(temp_path)
 
 
 
 
 
92
 
93
+ # Clean up
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  os.remove(temp_path)
95
 
96
  return JSONResponse({
97
  "status": "success",
98
+ "result": result,
99
+ "timestamp": datetime.now().isoformat()
 
 
100
  })
101
 
 
 
102
  except Exception as e:
103
  raise HTTPException(status_code=500, detail=str(e))
104
 
 
 
 
 
 
 
 
 
105
  @app.get("/status")
106
  async def service_status():
107
  """Check service status."""
108
  return {
109
  "status": "running",
110
  "version": "1.0.0",
111
+ "model": agent.model_name if agent else "not loaded",
112
+ "device": str(agent.device) if agent else "unknown"
 
113
  }
114
 
115
  if __name__ == "__main__":
116
+ import uvicorn
117
  uvicorn.run(app, host="0.0.0.0", port=7860)