Ali2206 commited on
Commit
dfff005
·
verified ·
1 Parent(s): 85d52f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -122
app.py CHANGED
@@ -1,41 +1,39 @@
 
 
1
  import os
2
  import sys
3
  import json
4
  import re
5
  import logging
 
6
  from datetime import datetime
7
  from typing import List, Dict, Optional
 
8
  from fastapi import FastAPI, HTTPException, UploadFile, File
9
  from fastapi.responses import JSONResponse, StreamingResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel
12
  import markdown
13
  import PyPDF2
14
- import asyncio
15
 
16
- # Setup logging
17
- logging.basicConfig(
18
- level=logging.INFO,
19
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
20
- )
21
  logger = logging.getLogger("TxAgentAPI")
22
 
23
- # Adjust sys path
24
  current_dir = os.path.dirname(os.path.abspath(__file__))
25
- src_path = os.path.abspath(os.path.join(current_dir, "src"))
26
- sys.path.insert(0, src_path)
27
 
28
- # Import TxAgent
29
- try:
30
- from txagent.txagent import TxAgent
31
- except ImportError as e:
32
- logger.error(f"Failed to import TxAgent: {str(e)}")
33
- raise
34
 
35
- # Init FastAPI
 
 
 
36
  app = FastAPI(title="TxAgent API", version="2.1.0")
37
 
38
- # CORS
39
  app.add_middleware(
40
  CORSMiddleware,
41
  allow_origins=["*"],
@@ -44,7 +42,7 @@ app.add_middleware(
44
  allow_headers=["*"],
45
  )
46
 
47
- # Request schema
48
  class ChatRequest(BaseModel):
49
  message: str
50
  temperature: float = 0.7
@@ -52,149 +50,155 @@ class ChatRequest(BaseModel):
52
  history: Optional[List[Dict]] = None
53
  format: Optional[str] = "clean"
54
 
55
- # Response formatting
56
- def clean_text_response(text: str) -> str:
57
  text = re.sub(r'\n\s*\n', '\n\n', text)
58
  text = re.sub(r'[ ]+', ' ', text)
59
- text = text.replace("**", "").replace("__", "")
60
- return text.strip()
 
 
 
 
 
 
 
61
 
62
  def structure_medical_response(text: str) -> Dict:
63
  return {
64
  "summary": extract_section(text, "Summary"),
65
  "risks": extract_section(text, "Risks or Red Flags"),
66
  "missed_issues": extract_section(text, "What the doctor might have missed"),
67
- "recommendations": extract_section(text, "Suggested Clinical Actions")
68
  }
69
 
70
- def extract_section(text: str, heading: str) -> str:
 
 
 
 
71
  try:
72
- pattern = rf"{heading}:\n(.*?)(?=\n\w|\Z)"
73
- match = re.search(pattern, text, re.DOTALL)
74
- return clean_text_response(match.group(1)) if match else ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  except Exception as e:
76
- logger.error(f"Section extraction failed: {e}")
77
- return ""
78
 
79
- # Agent init
80
- agent = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  @app.on_event("startup")
83
  async def startup_event():
84
  global agent
85
- try:
86
- agent = TxAgent(
87
- model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
88
- rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
89
- enable_finish=True,
90
- enable_rag=False,
91
- force_finish=True,
92
- enable_checker=True,
93
- step_rag_num=4,
94
- seed=100
95
- )
96
- agent.chat_prompt = (
97
- "You are a clinical decision support assistant for doctors. "
98
- "You analyze patient documents, detect medical issues, identify missed diagnoses, "
99
- "and provide treatment suggestions with rationale in concise, readable language."
100
- )
101
- agent.init_model()
102
- logger.info("TxAgent initialized successfully")
103
- except Exception as e:
104
- logger.error(f"Startup error: {str(e)}")
105
 
106
  @app.post("/chat-stream")
107
- async def chat_stream_endpoint(request: ChatRequest):
108
- async def token_stream():
109
  try:
110
- conversation = []
111
- conversation.append({"role": "system", "content": agent.chat_prompt})
112
  if request.history:
113
- for msg in request.history:
114
- conversation.append({"role": msg["role"], "content": msg["content"]})
115
- conversation.append({"role": "user", "content": request.message})
116
-
117
- input_ids = agent.tokenizer.apply_chat_template(
118
- conversation,
119
- add_generation_prompt=True,
120
- return_tensors="pt"
121
- ).to(agent.device)
122
-
123
- streamer = agent.model.generate(
124
- input_ids,
125
- do_sample=True,
126
- temperature=request.temperature,
127
- max_new_tokens=request.max_new_tokens,
128
- pad_token_id=agent.tokenizer.eos_token_id,
129
- return_dict_in_generate=True,
130
- output_scores=False
131
- )
132
-
133
- output = agent.tokenizer.decode(streamer["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True)
134
-
135
- for chunk in output.split():
136
  yield chunk + " "
137
  await asyncio.sleep(0.05)
138
-
139
  except Exception as e:
140
- logger.error(f"Streaming chat error: {str(e)}")
141
- yield f"\n⚠️ Error: {str(e)}"
142
 
143
- return StreamingResponse(token_stream(), media_type="text/plain")
144
 
145
  @app.post("/upload")
146
  async def upload_file(file: UploadFile = File(...)):
147
  try:
148
- logger.info(f"File upload received: {file.filename}")
149
- content = ""
150
  if file.filename.endswith(".pdf"):
151
- pdf_reader = PyPDF2.PdfReader(file.file)
152
- for page in pdf_reader.pages:
153
- content += page.extract_text() or ""
154
  else:
155
  content = await file.read()
156
- content = content.decode("utf-8", errors="ignore")
157
-
158
- message = f"""
159
- You are a clinical decision support AI assisting physicians.
160
-
161
- Given the following patient report, do the following:
162
- 1. Summarize the patient's main conditions and history.
163
- 2. Identify any potential clinical risks or red flags.
164
- 3. Highlight any important diagnoses or treatments the doctor might have missed.
165
- 4. Suggest next clinical steps, treatments, or referrals (if applicable).
166
- 5. Flag anything that could pose an urgent risk (e.g., suicide risk, untreated critical conditions).
167
-
168
- Patient Document:
169
- -----------------
170
- {content[:10000]}
171
- """
172
-
173
- raw_response = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
174
- formatted_response = {
175
- "raw": raw_response,
176
- "clean": clean_text_response(raw_response),
177
- "structured": structure_medical_response(raw_response),
178
- "html": markdown.markdown(raw_response)
179
- }
180
- return JSONResponse({
181
  "status": "success",
182
- "format": "clean",
183
- "response": formatted_response["clean"],
184
- "timestamp": datetime.now().isoformat(),
185
- "available_formats": list(formatted_response.keys())
186
- })
187
  except Exception as e:
188
- logger.error(f"File upload error: {str(e)}")
189
  raise HTTPException(status_code=500, detail=str(e))
190
- finally:
191
- file.file.close()
192
 
193
  @app.get("/status")
194
  async def status():
195
  return {
196
  "status": "running",
197
- "version": "2.1.0",
198
- "model": agent.model_name if agent else "not loaded",
199
  "timestamp": datetime.now().isoformat()
200
  }
 
1
+ # app.py
2
+
3
  import os
4
  import sys
5
  import json
6
  import re
7
  import logging
8
+ import asyncio
9
  from datetime import datetime
10
  from typing import List, Dict, Optional
11
+
12
  from fastapi import FastAPI, HTTPException, UploadFile, File
13
  from fastapi.responses import JSONResponse, StreamingResponse
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel
16
  import markdown
17
  import PyPDF2
 
18
 
19
+ # Logging setup
20
+ logging.basicConfig(level=logging.INFO)
 
 
 
21
  logger = logging.getLogger("TxAgentAPI")
22
 
23
+ # Path setup
24
  current_dir = os.path.dirname(os.path.abspath(__file__))
25
+ sys.path.insert(0, os.path.join(current_dir, "src"))
 
26
 
27
+ # TxAgent import
28
+ from txagent.txagent import TxAgent
 
 
 
 
29
 
30
+ # MongoDB collections (shared URI via Hugging Face secrets)
31
+ from db.mongo import patients_collection, results_collection
32
+
33
+ # FastAPI app
34
  app = FastAPI(title="TxAgent API", version="2.1.0")
35
 
36
+ # CORS config
37
  app.add_middleware(
38
  CORSMiddleware,
39
  allow_origins=["*"],
 
42
  allow_headers=["*"],
43
  )
44
 
45
+ # Pydantic schema
46
  class ChatRequest(BaseModel):
47
  message: str
48
  temperature: float = 0.7
 
50
  history: Optional[List[Dict]] = None
51
  format: Optional[str] = "clean"
52
 
53
+ # Utils
54
+ def clean_text(text: str) -> str:
55
  text = re.sub(r'\n\s*\n', '\n\n', text)
56
  text = re.sub(r'[ ]+', ' ', text)
57
+ return text.strip().replace("**", "").replace("__", "")
58
+
59
+ def extract_section(text: str, heading: str) -> str:
60
+ try:
61
+ pattern = rf"{heading}:\n(.*?)(?=\n[A-Z]|\Z)"
62
+ match = re.search(pattern, text, re.DOTALL)
63
+ return clean_text(match.group(1)) if match else ""
64
+ except:
65
+ return ""
66
 
67
  def structure_medical_response(text: str) -> Dict:
68
  return {
69
  "summary": extract_section(text, "Summary"),
70
  "risks": extract_section(text, "Risks or Red Flags"),
71
  "missed_issues": extract_section(text, "What the doctor might have missed"),
72
+ "recommendations": extract_section(text, "Suggested Clinical Actions"),
73
  }
74
 
75
+ # Global agent
76
+ agent = None
77
+
78
+ # Background logic
79
+ async def analyze_and_store_result(patient: dict):
80
  try:
81
+ content = json.dumps(patient, indent=2)[:10000]
82
+ message = (
83
+ "You are a clinical AI assistant.\n\n"
84
+ "Analyze this patient's record and:\n"
85
+ "1. Summarize conditions and history.\n"
86
+ "2. Identify red flags.\n"
87
+ "3. Detect missed issues.\n"
88
+ "4. Suggest clinical actions.\n\n"
89
+ f"Patient Data:\n{content}"
90
+ )
91
+
92
+ raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
93
+ structured = structure_medical_response(raw)
94
+
95
+ await results_collection.update_one(
96
+ {"patient_id": patient.get("fhir_id")},
97
+ {
98
+ "$set": {
99
+ "patient_id": patient.get("fhir_id"),
100
+ "full_name": patient.get("full_name"),
101
+ "raw": raw,
102
+ "structured": structured,
103
+ "analyzed_at": datetime.utcnow()
104
+ }
105
+ },
106
+ upsert=True
107
+ )
108
+ logger.info(f"Stored analysis for {patient.get('fhir_id')}")
109
  except Exception as e:
110
+ logger.error(f"Error analyzing patient: {e}")
 
111
 
112
+ async def analyze_existing_patients():
113
+ try:
114
+ patients = await patients_collection.find({}).to_list(length=None)
115
+ for patient in patients:
116
+ await analyze_and_store_result(patient)
117
+ await asyncio.sleep(0.3)
118
+ except Exception as e:
119
+ logger.error(f"Batch analysis failed: {e}")
120
+
121
+ async def watch_new_patients():
122
+ try:
123
+ logger.info("Watching for new patient inserts...")
124
+ pipeline = [{'$match': {'operationType': 'insert'}}]
125
+ async with patients_collection.watch(pipeline) as stream:
126
+ async for change in stream:
127
+ patient = change["fullDocument"]
128
+ await analyze_and_store_result(patient)
129
+ except Exception as e:
130
+ logger.error(f"Change stream error: {e}")
131
 
132
  @app.on_event("startup")
133
  async def startup_event():
134
  global agent
135
+ agent = TxAgent(
136
+ model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
137
+ rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
138
+ enable_finish=True,
139
+ enable_checker=True,
140
+ force_finish=True,
141
+ )
142
+ agent.chat_prompt = (
143
+ "You are a clinical decision support AI helping doctors review patient records and suggest care plans."
144
+ )
145
+ agent.init_model()
146
+ logger.info("TxAgent loaded")
147
+ asyncio.create_task(analyze_existing_patients())
148
+ asyncio.create_task(watch_new_patients())
 
 
 
 
 
 
149
 
150
  @app.post("/chat-stream")
151
+ async def chat_stream(request: ChatRequest):
152
+ async def stream():
153
  try:
154
+ msgs = [{"role": "system", "content": agent.chat_prompt}]
 
155
  if request.history:
156
+ msgs += request.history
157
+ msgs.append({"role": "user", "content": request.message})
158
+
159
+ input_ids = agent.tokenizer.apply_chat_template(msgs, add_generation_prompt=True, return_tensors="pt").to(agent.device)
160
+ output = agent.model.generate(input_ids, do_sample=True, temperature=request.temperature, max_new_tokens=request.max_new_tokens)
161
+ text = agent.tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
162
+
163
+ for chunk in text.split():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  yield chunk + " "
165
  await asyncio.sleep(0.05)
 
166
  except Exception as e:
167
+ yield f"\n⚠️ Error: {e}"
 
168
 
169
+ return StreamingResponse(stream(), media_type="text/plain")
170
 
171
  @app.post("/upload")
172
  async def upload_file(file: UploadFile = File(...)):
173
  try:
174
+ logger.info(f"Uploaded file: {file.filename}")
175
+ text = ""
176
  if file.filename.endswith(".pdf"):
177
+ pdf = PyPDF2.PdfReader(file.file)
178
+ text = "\n".join(p.extract_text() for p in pdf.pages if p.extract_text())
 
179
  else:
180
  content = await file.read()
181
+ text = content.decode("utf-8", errors="ignore")
182
+
183
+ prompt = (
184
+ "You are a clinical support AI. Analyze the following:\n"
185
+ f"{text[:10000]}"
186
+ )
187
+ raw = agent.chat(message=prompt, history=[], temperature=0.7)
188
+ return {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  "status": "success",
190
+ "response": clean_text(raw),
191
+ "structured": structure_medical_response(raw),
192
+ "timestamp": datetime.now().isoformat()
193
+ }
 
194
  except Exception as e:
195
+ logger.error(f"Upload error: {e}")
196
  raise HTTPException(status_code=500, detail=str(e))
 
 
197
 
198
  @app.get("/status")
199
  async def status():
200
  return {
201
  "status": "running",
202
+ "model": agent.model_name,
 
203
  "timestamp": datetime.now().isoformat()
204
  }