Ali2206 committed on
Commit
60e4c3d
·
verified ·
1 Parent(s): f3089ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -58
app.py CHANGED
@@ -3,42 +3,38 @@ import sys
3
  import json
4
  import re
5
  import logging
 
 
6
  from fastapi import FastAPI, HTTPException, UploadFile, File
7
  from fastapi.responses import JSONResponse
8
  from fastapi.middleware.cors import CORSMiddleware
9
- from typing import List, Dict, Optional
10
- from datetime import datetime
11
  from pydantic import BaseModel
12
  import markdown
13
  import PyPDF2
14
 
15
- # Configure logging
16
  logging.basicConfig(
17
  level=logging.INFO,
18
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
19
  )
20
  logger = logging.getLogger("TxAgentAPI")
21
 
22
- # Add src directory to Python path
23
  current_dir = os.path.dirname(os.path.abspath(__file__))
24
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
25
  sys.path.insert(0, src_path)
26
 
27
- # Import TxAgent after setting up path
28
  try:
29
  from txagent.txagent import TxAgent
30
  except ImportError as e:
31
  logger.error(f"Failed to import TxAgent: {str(e)}")
32
  raise
33
 
34
- # Initialize FastAPI app
35
- app = FastAPI(
36
- title="TxAgent API",
37
- description="API for TxAgent medical document analysis",
38
- version="2.0.0"
39
- )
40
 
41
- # CORS configuration
42
  app.add_middleware(
43
  CORSMiddleware,
44
  allow_origins=["*"],
@@ -47,50 +43,48 @@ app.add_middleware(
47
  allow_headers=["*"],
48
  )
49
 
50
- # Request models
51
  class ChatRequest(BaseModel):
52
  message: str
53
  temperature: float = 0.7
54
  max_new_tokens: int = 512
55
  history: Optional[List[Dict]] = None
56
- format: Optional[str] = "clean" # Options: raw, clean, structured, html
57
 
58
- # Response cleaning functions
59
  def clean_text_response(text: str) -> str:
60
- """Basic text cleaning"""
61
  text = re.sub(r'\n\s*\n', '\n\n', text)
62
  text = re.sub(r'[ ]+', ' ', text)
63
  text = text.replace("**", "").replace("__", "")
64
  return text.strip()
65
 
66
  def structure_medical_response(text: str) -> Dict:
67
- """Structure medical content into categories"""
68
- result = {"overview": "", "symptoms": [], "types": {}, "notes": ""}
69
- overview_end = text.find("Type 1 Diabetes:")
70
- result["overview"] = clean_text_response(text[:overview_end])
71
- type_sections = re.split(r'(Type \d Diabetes:)', text[overview_end:])
72
- current_type = None
73
- for section in type_sections:
74
- if section.startswith("Type "):
75
- current_type = section.replace(":", "").strip().lower()
76
- result["types"][current_type] = []
77
- elif current_type:
78
- points = [p.strip() for p in section.split('\n') if p.strip()]
79
- result["types"][current_type] = points
80
- return result
81
-
82
- # Initialize agent at startup
 
83
  agent = None
84
 
85
  @app.on_event("startup")
86
  async def startup_event():
87
  global agent
88
  try:
89
- logger.info("Initializing TxAgent...")
90
  agent = TxAgent(
91
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
92
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
93
- tool_files_dict={},
94
  enable_finish=True,
95
  enable_rag=False,
96
  force_finish=True,
@@ -98,17 +92,19 @@ async def startup_event():
98
  step_rag_num=4,
99
  seed=100
100
  )
 
 
 
 
 
101
  agent.init_model()
102
  logger.info("TxAgent initialized successfully")
103
  except Exception as e:
104
- logger.error(f"Failed to initialize agent: {str(e)}")
105
- raise RuntimeError(f"Failed to initialize agent: {str(e)}")
106
 
107
  @app.post("/chat")
108
  async def chat_endpoint(request: ChatRequest):
109
- """Handle chat conversations with formatting options"""
110
  try:
111
- logger.info(f"Chat request received (format: {request.format})")
112
  raw_response = agent.chat(
113
  message=request.message,
114
  history=request.history,
@@ -121,11 +117,10 @@ async def chat_endpoint(request: ChatRequest):
121
  "structured": structure_medical_response(raw_response),
122
  "html": markdown.markdown(raw_response)
123
  }
124
- response_content = formatted_response.get(request.format, formatted_response["clean"])
125
  return JSONResponse({
126
  "status": "success",
127
  "format": request.format,
128
- "response": response_content,
129
  "timestamp": datetime.now().isoformat(),
130
  "available_formats": list(formatted_response.keys())
131
  })
@@ -135,36 +130,43 @@ async def chat_endpoint(request: ChatRequest):
135
 
136
  @app.post("/upload")
137
  async def upload_file(file: UploadFile = File(...)):
138
- """Handle file uploads and process with TxAgent"""
139
  try:
140
  logger.info(f"File upload received: {file.filename}")
141
  content = ""
142
- if file.filename.endswith('.pdf'):
143
  pdf_reader = PyPDF2.PdfReader(file.file)
144
  for page in pdf_reader.pages:
145
  content += page.extract_text() or ""
146
  else:
147
  content = await file.read()
148
- content = content.decode('utf-8', errors='ignore')
149
-
150
- message = f"Analyze the following medical document content:\n\n{content[:10000]}"
151
- raw_response = agent.chat(
152
- message=message,
153
- history=[],
154
- temperature=0.7,
155
- max_new_tokens=512
156
- )
 
 
 
 
 
 
 
 
 
157
  formatted_response = {
158
  "raw": raw_response,
159
  "clean": clean_text_response(raw_response),
160
  "structured": structure_medical_response(raw_response),
161
  "html": markdown.markdown(raw_response)
162
  }
163
- response_content = formatted_response["clean"]
164
  return JSONResponse({
165
  "status": "success",
166
  "format": "clean",
167
- "response": response_content,
168
  "timestamp": datetime.now().isoformat(),
169
  "available_formats": list(formatted_response.keys())
170
  })
@@ -175,12 +177,10 @@ async def upload_file(file: UploadFile = File(...)):
175
  file.file.close()
176
 
177
  @app.get("/status")
178
- async def service_status():
179
- """Check service status"""
180
  return {
181
  "status": "running",
182
- "version": "2.0.0",
183
  "model": agent.model_name if agent else "not loaded",
184
- "formats_available": ["raw", "clean", "structured", "html"],
185
  "timestamp": datetime.now().isoformat()
186
- }
 
3
  import json
4
  import re
5
  import logging
6
+ from datetime import datetime
7
+ from typing import List, Dict, Optional
8
  from fastapi import FastAPI, HTTPException, UploadFile, File
9
  from fastapi.responses import JSONResponse
10
  from fastapi.middleware.cors import CORSMiddleware
 
 
11
  from pydantic import BaseModel
12
  import markdown
13
  import PyPDF2
14
 
15
+ # Setup logging
16
  logging.basicConfig(
17
  level=logging.INFO,
18
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
19
  )
20
  logger = logging.getLogger("TxAgentAPI")
21
 
22
+ # Adjust sys path
23
  current_dir = os.path.dirname(os.path.abspath(__file__))
24
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
25
  sys.path.insert(0, src_path)
26
 
27
+ # Import TxAgent
28
  try:
29
  from txagent.txagent import TxAgent
30
  except ImportError as e:
31
  logger.error(f"Failed to import TxAgent: {str(e)}")
32
  raise
33
 
34
+ # Init FastAPI
35
+ app = FastAPI(title="TxAgent API", version="2.1.0")
 
 
 
 
36
 
37
+ # CORS
38
  app.add_middleware(
39
  CORSMiddleware,
40
  allow_origins=["*"],
 
43
  allow_headers=["*"],
44
  )
45
 
46
+ # Request schema
47
  class ChatRequest(BaseModel):
48
  message: str
49
  temperature: float = 0.7
50
  max_new_tokens: int = 512
51
  history: Optional[List[Dict]] = None
52
+ format: Optional[str] = "clean"
53
 
54
+ # Response formatting
55
def clean_text_response(text: str) -> str:
    """Return ``text`` with tidied whitespace and Markdown emphasis removed.

    Runs of blank (or whitespace-only) lines collapse to a single blank line,
    runs of spaces collapse to one space, the ``**`` / ``__`` emphasis markers
    are dropped, and leading/trailing whitespace is stripped.
    """
    text = re.sub(r'\n\s*\n', '\n\n', text)
    text = re.sub(r'[ ]+', ' ', text)
    text = text.replace("**", "").replace("__", "")
    return text.strip()


def structure_medical_response(text: str) -> Dict:
    """Split a model answer into the four sections the upload prompt asks for.

    Returns a dict with the keys ``summary``, ``risks``, ``missed_issues`` and
    ``recommendations``. Each value is the cleaned body of the matching
    section, or ``""`` when that heading does not occur in ``text``.
    """
    return {
        "summary": extract_section(text, "Summary"),
        "risks": extract_section(text, "Risks or Red Flags"),
        "missed_issues": extract_section(text, "What the doctor might have missed"),
        "recommendations": extract_section(text, "Suggested Clinical Actions"),
    }


# Optional decoration in front of a heading: indentation, an ATX marker
# ("## "), a list number ("1. " / "1) ") and an opening emphasis marker.
_HEADING_PREFIX = r"[ \t]*(?:#{1,6}[ \t]*)?(?:\d+[.)][ \t]*)?(?:\*\*|__)?"

# A line that starts a *new* section: either a Markdown ATX heading, or a
# short title that begins with a letter and ends with a colon. Bullet items
# ("- foo:", "* foo:") and "key: value" lines are deliberately not matched.
_HEADING_LINE = re.compile(
    r"^[ \t]*#{1,6}[ \t]+\S"
    r"|^" + _HEADING_PREFIX + r"[A-Za-z][^\n:]{0,79}:(?:\*\*|__)?[ \t]*$"
)


def extract_section(text: str, heading: str) -> str:
    """Return the cleaned body of the section titled ``heading`` in ``text``.

    The heading is matched literally (regex metacharacters in ``heading`` are
    escaped) at the start of a line, optionally decorated with Markdown
    (``## Heading:``, ``**Heading:**``, ``1. Heading:``), and must be followed
    by a colon. Text after the colon on the same line is part of the body.
    The body runs until the next heading-like line (see ``_HEADING_LINE``) or
    the end of ``text``, so multi-line paragraphs are kept whole.

    Known limitation: a body line that itself is a short phrase ending in a
    colon is taken as the start of the next section.

    Returns ``""`` when the heading is absent. Extraction is best-effort: any
    unexpected error is logged and ``""`` is returned.
    """
    try:
        opener = re.compile(
            "^" + _HEADING_PREFIX + re.escape(heading)
            + r"(?:\*\*|__)?[ \t]*:(?:\*\*|__)?[ \t]*(.*)$"
        )
        body = None  # stays None until the heading line has been seen
        for line in text.splitlines():
            if body is None:
                found = opener.match(line)
                if found:
                    body = [found.group(1)]
            elif _HEADING_LINE.match(line):
                break
            else:
                body.append(line)
        if body is None:
            return ""
        return clean_text_response("\n".join(body))
    except Exception as e:
        # Deliberate catch-all: a formatting failure must not fail the request.
        logger.error(f"Section extraction failed: {e}")
        return ""
77
+
78
+ # Agent init
79
  agent = None
80
 
81
  @app.on_event("startup")
82
  async def startup_event():
83
  global agent
84
  try:
 
85
  agent = TxAgent(
86
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
87
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
88
  enable_finish=True,
89
  enable_rag=False,
90
  force_finish=True,
 
92
  step_rag_num=4,
93
  seed=100
94
  )
95
+ agent.chat_prompt = (
96
+ "You are a clinical decision support assistant for doctors. "
97
+ "You analyze patient documents, detect medical issues, identify missed diagnoses, "
98
+ "and provide treatment suggestions with rationale in concise, readable language."
99
+ )
100
  agent.init_model()
101
  logger.info("TxAgent initialized successfully")
102
  except Exception as e:
103
+ logger.error(f"Startup error: {str(e)}")
 
104
 
105
  @app.post("/chat")
106
  async def chat_endpoint(request: ChatRequest):
 
107
  try:
 
108
  raw_response = agent.chat(
109
  message=request.message,
110
  history=request.history,
 
117
  "structured": structure_medical_response(raw_response),
118
  "html": markdown.markdown(raw_response)
119
  }
 
120
  return JSONResponse({
121
  "status": "success",
122
  "format": request.format,
123
+ "response": formatted_response.get(request.format, formatted_response["clean"]),
124
  "timestamp": datetime.now().isoformat(),
125
  "available_formats": list(formatted_response.keys())
126
  })
 
130
 
131
  @app.post("/upload")
132
  async def upload_file(file: UploadFile = File(...)):
 
133
  try:
134
  logger.info(f"File upload received: {file.filename}")
135
  content = ""
136
+ if file.filename.endswith(".pdf"):
137
  pdf_reader = PyPDF2.PdfReader(file.file)
138
  for page in pdf_reader.pages:
139
  content += page.extract_text() or ""
140
  else:
141
  content = await file.read()
142
+ content = content.decode("utf-8", errors="ignore")
143
+
144
+ message = f"""
145
+ You are a clinical decision support AI assisting physicians.
146
+
147
+ Given the following patient report, do the following:
148
+ 1. Summarize the patient's main conditions and history.
149
+ 2. Identify any potential clinical risks or red flags.
150
+ 3. Highlight any important diagnoses or treatments the doctor might have missed.
151
+ 4. Suggest next clinical steps, treatments, or referrals (if applicable).
152
+ 5. Flag anything that could pose an urgent risk (e.g., suicide risk, untreated critical conditions).
153
+
154
+ Patient Document:
155
+ -----------------
156
+ {content[:10000]}
157
+ """
158
+
159
+ raw_response = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
160
  formatted_response = {
161
  "raw": raw_response,
162
  "clean": clean_text_response(raw_response),
163
  "structured": structure_medical_response(raw_response),
164
  "html": markdown.markdown(raw_response)
165
  }
 
166
  return JSONResponse({
167
  "status": "success",
168
  "format": "clean",
169
+ "response": formatted_response["clean"],
170
  "timestamp": datetime.now().isoformat(),
171
  "available_formats": list(formatted_response.keys())
172
  })
 
177
  file.file.close()
178
 
179
  @app.get("/status")
180
+ async def status():
 
181
  return {
182
  "status": "running",
183
+ "version": "2.1.0",
184
  "model": agent.model_name if agent else "not loaded",
 
185
  "timestamp": datetime.now().isoformat()
186
+ }