Ali2206 commited on
Commit
ab172ce
·
verified ·
1 Parent(s): 7e0e71d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -125
app.py CHANGED
@@ -1,48 +1,42 @@
1
- # app.py
2
-
3
  import os
4
  import sys
5
  import json
6
- import re
7
  import logging
8
- import asyncio
9
  from datetime import datetime
10
  from typing import List, Dict, Optional
11
 
12
- from fastapi import FastAPI, HTTPException, UploadFile, File
13
  from fastapi.responses import JSONResponse, StreamingResponse
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel
16
- import markdown
17
- import PyPDF2
18
-
19
- # Logging setup
20
- logging.basicConfig(level=logging.INFO)
21
- logger = logging.getLogger("TxAgentAPI")
22
 
23
- # Path setup
24
- current_dir = os.path.dirname(os.path.abspath(__file__))
25
- sys.path.insert(0, os.path.join(current_dir, "src"))
26
 
27
- # TxAgent import
28
  from txagent.txagent import TxAgent
29
 
30
- # MongoDB collections (shared URI via Hugging Face secrets)
31
- from db.mongo import patients_collection, results_collection
 
 
 
 
32
 
33
  # FastAPI app
34
  app = FastAPI(title="TxAgent API", version="2.1.0")
35
 
36
- # CORS config
37
  app.add_middleware(
38
  CORSMiddleware,
39
- allow_origins=["*"],
40
- allow_credentials=True,
41
- allow_methods=["*"],
42
- allow_headers=["*"],
43
  )
44
 
45
- # Pydantic schema
46
  class ChatRequest(BaseModel):
47
  message: str
48
  temperature: float = 0.7
@@ -50,18 +44,25 @@ class ChatRequest(BaseModel):
50
  history: Optional[List[Dict]] = None
51
  format: Optional[str] = "clean"
52
 
53
- # Utils
54
- def clean_text(text: str) -> str:
 
 
 
 
 
 
55
  text = re.sub(r'\n\s*\n', '\n\n', text)
56
  text = re.sub(r'[ ]+', ' ', text)
57
- return text.strip().replace("**", "").replace("__", "")
58
 
59
  def extract_section(text: str, heading: str) -> str:
60
  try:
61
- pattern = rf"{heading}:\n(.*?)(?=\n[A-Z]|\Z)"
62
  match = re.search(pattern, text, re.DOTALL)
63
- return clean_text(match.group(1)) if match else ""
64
- except:
 
65
  return ""
66
 
67
  def structure_medical_response(text: str) -> Dict:
@@ -69,136 +70,120 @@ def structure_medical_response(text: str) -> Dict:
69
  "summary": extract_section(text, "Summary"),
70
  "risks": extract_section(text, "Risks or Red Flags"),
71
  "missed_issues": extract_section(text, "What the doctor might have missed"),
72
- "recommendations": extract_section(text, "Suggested Clinical Actions"),
73
  }
74
 
75
- # Global agent
76
- agent = None
 
 
 
77
 
78
- # Background logic
79
- async def analyze_and_store_result(patient: dict):
80
  try:
81
- content = json.dumps(patient, indent=2)[:10000]
82
  message = (
83
- "You are a clinical AI assistant.\n\n"
84
- "Analyze this patient's record and:\n"
85
- "1. Summarize conditions and history.\n"
86
- "2. Identify red flags.\n"
87
- "3. Detect missed issues.\n"
88
- "4. Suggest clinical actions.\n\n"
89
- f"Patient Data:\n{content}"
90
  )
91
 
92
  raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
93
  structured = structure_medical_response(raw)
94
 
95
- await results_collection.update_one(
 
 
 
 
 
 
96
  {"patient_id": patient.get("fhir_id")},
97
- {
98
- "$set": {
99
- "patient_id": patient.get("fhir_id"),
100
- "full_name": patient.get("full_name"),
101
- "raw": raw,
102
- "structured": structured,
103
- "analyzed_at": datetime.utcnow()
104
- }
105
- },
106
  upsert=True
107
  )
108
- logger.info(f"Stored analysis for {patient.get('fhir_id')}")
109
  except Exception as e:
110
  logger.error(f"Error analyzing patient: {e}")
111
 
112
- async def analyze_existing_patients():
113
- try:
114
- patients = await patients_collection.find({}).to_list(length=None)
115
- for patient in patients:
116
- await analyze_and_store_result(patient)
117
- await asyncio.sleep(0.3)
118
- except Exception as e:
119
- logger.error(f"Batch analysis failed: {e}")
120
-
121
- async def watch_new_patients():
122
- try:
123
- logger.info("Watching for new patient inserts...")
124
- pipeline = [{'$match': {'operationType': 'insert'}}]
125
- async with patients_collection.watch(pipeline) as stream:
126
- async for change in stream:
127
- patient = change["fullDocument"]
128
- await analyze_and_store_result(patient)
129
- except Exception as e:
130
- logger.error(f"Change stream error: {e}")
131
 
 
132
  @app.on_event("startup")
133
  async def startup_event():
134
- global agent
 
 
135
  agent = TxAgent(
136
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
137
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
138
  enable_finish=True,
139
- enable_checker=True,
140
  force_finish=True,
 
 
 
141
  )
142
  agent.chat_prompt = (
143
- "You are a clinical decision support AI helping doctors review patient records and suggest care plans."
144
  )
145
  agent.init_model()
146
- logger.info("TxAgent loaded")
147
- asyncio.create_task(analyze_existing_patients())
148
- asyncio.create_task(watch_new_patients())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  @app.post("/chat-stream")
151
- async def chat_stream(request: ChatRequest):
152
- async def stream():
153
  try:
154
- msgs = [{"role": "system", "content": agent.chat_prompt}]
155
  if request.history:
156
- msgs += request.history
157
- msgs.append({"role": "user", "content": request.message})
158
-
159
- input_ids = agent.tokenizer.apply_chat_template(msgs, add_generation_prompt=True, return_tensors="pt").to(agent.device)
160
- output = agent.model.generate(input_ids, do_sample=True, temperature=request.temperature, max_new_tokens=request.max_new_tokens)
161
- text = agent.tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
162
-
 
 
 
 
 
 
 
 
 
 
163
  for chunk in text.split():
164
  yield chunk + " "
165
  await asyncio.sleep(0.05)
166
  except Exception as e:
167
- yield f"\n⚠️ Error: {e}"
168
-
169
- return StreamingResponse(stream(), media_type="text/plain")
170
 
171
- @app.post("/upload")
172
- async def upload_file(file: UploadFile = File(...)):
173
- try:
174
- logger.info(f"Uploaded file: {file.filename}")
175
- text = ""
176
- if file.filename.endswith(".pdf"):
177
- pdf = PyPDF2.PdfReader(file.file)
178
- text = "\n".join(p.extract_text() for p in pdf.pages if p.extract_text())
179
- else:
180
- content = await file.read()
181
- text = content.decode("utf-8", errors="ignore")
182
-
183
- prompt = (
184
- "You are a clinical support AI. Analyze the following:\n"
185
- f"{text[:10000]}"
186
- )
187
- raw = agent.chat(message=prompt, history=[], temperature=0.7)
188
- return {
189
- "status": "success",
190
- "response": clean_text(raw),
191
- "structured": structure_medical_response(raw),
192
- "timestamp": datetime.now().isoformat()
193
- }
194
- except Exception as e:
195
- logger.error(f"Upload error: {e}")
196
- raise HTTPException(status_code=500, detail=str(e))
197
-
198
- @app.get("/status")
199
- async def status():
200
- return {
201
- "status": "running",
202
- "model": agent.model_name,
203
- "timestamp": datetime.now().isoformat()
204
- }
 
 
 
import asyncio
import json
import logging
import os
import re
import sys
from datetime import datetime, timezone
from typing import List, Dict, Optional

from bson import ObjectId
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from pymongo import MongoClient
 
 
 
16
 
17
# Make the local "src" directory importable before pulling in project modules.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))

# TxAgent model wrapper (lives under src/; import must follow the sys.path edit above).
from txagent.txagent import TxAgent

# Project-local MongoDB client factory.
from db.mongo import get_mongo_client

# Service-wide logging configuration.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("TxAgentAPI")
29
 
30
# FastAPI app
app = FastAPI(title="TxAgent API", version="2.1.0")

# Permissive CORS: any origin, method, and header may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
38
 
39
+ # Models
40
  class ChatRequest(BaseModel):
41
  message: str
42
  temperature: float = 0.7
 
44
  history: Optional[List[Dict]] = None
45
  format: Optional[str] = "clean"
46
 
# Globals — populated in startup_event(); all remain None until the app boots.
agent = None
mongo_client = None
patients_collection = None
analysis_collection = None
52
+
# Helpers
def clean_text_response(text: str) -> str:
    """Tidy model output: collapse blank-line runs and space runs, then
    strip markdown bold/underline markers and surrounding whitespace."""
    cleaned = re.sub(r'\n\s*\n', '\n\n', text)
    cleaned = re.sub(r'[ ]+', ' ', cleaned)
    for marker in ("**", "__"):
        cleaned = cleaned.replace(marker, "")
    return cleaned.strip()
58
 
59
def extract_section(text: str, heading: str) -> str:
    """Pull the body of a named section out of the model's plain-text reply.

    A section starts at "<heading>:" on its own line and runs until the next
    line beginning with a word character, or the end of the text. Returns ""
    when the heading is absent or extraction fails.
    """
    try:
        # re.escape prevents regex metacharacters in a heading from
        # corrupting (or crashing) the pattern.
        pattern = rf"{re.escape(heading)}:\n(.*?)(?=\n\w|\Z)"
        match = re.search(pattern, text, re.DOTALL)
        return clean_text_response(match.group(1)) if match else ""
    except Exception as e:
        logger.error(f"Section extraction failed: {e}")
        return ""
67
 
68
def structure_medical_response(text: str) -> Dict:
    """Split a raw model reply into the four standard report sections."""
    headings = {
        "summary": "Summary",
        "risks": "Risks or Red Flags",
        "missed_issues": "What the doctor might have missed",
        "recommendations": "Suggested Clinical Actions",
    }
    return {key: extract_section(text, heading) for key, heading in headings.items()}
75
 
76
def serialize_patient(patient: dict) -> dict:
    """Return a shallow copy of *patient* with its Mongo "_id" stringified.

    Keeps the document JSON-serializable without mutating the original.
    """
    serialized = dict(patient)
    if "_id" in serialized:
        serialized["_id"] = str(serialized["_id"])
    return serialized
81
 
82
async def analyze_patient(patient: dict):
    """Run the TxAgent analysis for one patient document and upsert the result.

    Failures are logged and swallowed so a single bad document cannot abort
    the surrounding batch sweep.
    """
    try:
        # default=str guards against non-JSON types: serialize_patient handles
        # the ObjectId, but Mongo documents frequently carry datetime values
        # too, which would otherwise raise TypeError and skip this patient.
        doc = json.dumps(serialize_patient(patient), indent=2, default=str)
        message = (
            "You are a clinical decision support AI.\n\n"
            "Given the patient document below:\n"
            "1. Summarize their medical history.\n"
            "2. Identify risks or red flags.\n"
            "3. Highlight missed diagnoses or treatments.\n"
            "4. Suggest next clinical steps.\n"
            f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"  # cap prompt size
        )

        raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        structured = structure_medical_response(raw)

        analysis_doc = {
            "patient_id": patient.get("fhir_id"),
            # Timezone-aware; datetime.utcnow() is deprecated and naive.
            "timestamp": datetime.now(timezone.utc),
            "summary": structured,
            "raw": raw
        }
        # One analysis per patient, keyed by fhir_id — newest result wins.
        await analysis_collection.update_one(
            {"patient_id": patient.get("fhir_id")},
            {"$set": analysis_doc},
            upsert=True
        )
        logger.info(f"✔️ Analysis stored for patient {patient.get('fhir_id')}")
    except Exception as e:
        logger.error(f"Error analyzing patient: {e}")
112
 
113
async def analyze_all_patients():
    """One-shot sweep: run the analysis pipeline over every stored patient.

    Runs as a fire-and-forget asyncio task, so exceptions must be caught
    here — otherwise the sweep dies with only an unhandled-task traceback.
    """
    try:
        # NOTE(review): awaiting find().to_list() assumes an async driver
        # (e.g. motor) — confirm what get_mongo_client() returns.
        patients = await patients_collection.find({}).to_list(length=None)
        for patient in patients:
            await analyze_patient(patient)
            # Brief pause so back-to-back model calls don't saturate the agent.
            await asyncio.sleep(0.1)
    except Exception as e:
        logger.error(f"Batch patient analysis failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
# Startup logic
@app.on_event("startup")
async def startup_event():
    """Initialise the TxAgent model and MongoDB handles, then launch the
    background batch analysis of existing patients."""
    global agent, mongo_client, patients_collection, analysis_collection

    # Init agent
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42
    )
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    agent.init_model()
    logger.info("TxAgent initialized")

    # MongoDB
    # NOTE(review): analyze_all_patients/analyze_patient *await* these
    # collections, which requires an async driver (motor); confirm that
    # get_mongo_client() returns one, not a sync pymongo MongoClient.
    mongo_client = get_mongo_client()
    db = mongo_client.get_default_database()
    patients_collection = db.get_collection("patients")
    analysis_collection = db.get_collection("patient_analysis_results")

    logger.info("📡 Connected to MongoDB")
    # Fire-and-forget: sweep existing patients without blocking startup.
    asyncio.create_task(analyze_all_patients())
149
+
150
# Endpoints
@app.get("/status")
async def status():
    """Liveness probe: report service status, version, and current UTC time."""
    return {
        "status": "running",
        "version": "2.1.0",
        # Timezone-aware ISO timestamp; datetime.utcnow() is deprecated and naive.
        "timestamp": datetime.now(timezone.utc).isoformat()
    }
158
 
159
@app.post("/chat-stream")
async def chat_stream_endpoint(request: ChatRequest):
    """Generate a reply to request.message and stream it back word by word."""

    async def token_stream():
        # NOTE(review): the full completion is generated first and only then
        # chunked — this is not true token streaming, and model.generate() is
        # a blocking call on the event loop; consider offloading to a thread.
        try:
            # Transcript: system prompt, optional prior turns, new user message.
            conversation = [{"role": "system", "content": agent.chat_prompt}]
            if request.history:
                conversation.extend(request.history)
            conversation.append({"role": "user", "content": request.message})

            input_ids = agent.tokenizer.apply_chat_template(
                conversation, add_generation_prompt=True, return_tensors="pt"
            ).to(agent.device)

            output = agent.model.generate(
                input_ids,
                do_sample=True,
                temperature=request.temperature,
                max_new_tokens=request.max_new_tokens,
                pad_token_id=agent.tokenizer.eos_token_id,
                return_dict_in_generate=True
            )

            # Decode only the newly generated tokens (skip the prompt prefix).
            text = agent.tokenizer.decode(output["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True)
            for chunk in text.split():
                yield chunk + " "
                await asyncio.sleep(0.05)  # pacing so the client sees incremental output
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"⚠️ Error: {e}"

    return StreamingResponse(token_stream(), media_type="text/plain")