Ali2206 commited on
Commit
94662cd
·
verified ·
1 Parent(s): 8dff938

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -16
app.py CHANGED
@@ -4,7 +4,8 @@ import json
4
  import logging
5
  import re
6
  from datetime import datetime
7
- from typing import List, Dict, Optional
 
8
 
9
  from fastapi import FastAPI, HTTPException
10
  from fastapi.responses import StreamingResponse
@@ -20,7 +21,7 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
20
  logger = logging.getLogger("TxAgentAPI")
21
 
22
  # App
23
- app = FastAPI(title="TxAgent API", version="2.1.0")
24
 
25
  app.add_middleware(
26
  CORSMiddleware,
@@ -36,10 +37,19 @@ class ChatRequest(BaseModel):
36
  history: Optional[List[Dict]] = None
37
  format: Optional[str] = "clean"
38
 
 
 
 
 
 
 
 
 
39
  # Globals
40
  agent = None
41
  patients_collection = None
42
  analysis_collection = None
 
43
 
44
  # Helpers
45
  def clean_text_response(text: str) -> str:
@@ -59,25 +69,21 @@ def extract_section(text: str, heading: str) -> str:
59
  def structure_medical_response(text: str) -> Dict:
60
  """Improved version that handles both markdown and plain text formats"""
61
  def extract_improved(text: str, heading: str) -> str:
62
- # Try multiple patterns to match different heading formats
63
  patterns = [
64
- rf"{re.escape(heading)}:\s*\n(.*?)(?=\n\s*\n|\Z)", # Heading followed by content until double newline
65
- rf"\*\*{re.escape(heading)}\*\*:\s*\n(.*?)(?=\n\s*\n|\Z)", # Markdown bold heading
66
- rf"{re.escape(heading)}[\s\-]+(.*?)(?=\n\s*\n|\Z)", # Heading with dashes
67
- rf"\n{re.escape(heading)}\s*\n(.*?)(?=\n\s*\n|\Z)" # Heading on its own line
68
  ]
69
 
70
  for pattern in patterns:
71
  match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
72
  if match:
73
  content = match.group(1).strip()
74
- # Clean up any remaining markdown or special characters
75
  content = re.sub(r'^\s*[\-\*]\s*', '', content, flags=re.MULTILINE)
76
  return content
77
-
78
  return ""
79
 
80
- # Normalize the text first
81
  text = text.replace('**', '').replace('__', '')
82
 
83
  return {
@@ -90,6 +96,72 @@ def structure_medical_response(text: str) -> Dict:
90
  "recommendations": extract_improved(text, "Suggest Next Clinical Steps") or
91
  extract_improved(text, "Suggested Clinical Actions")
92
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def serialize_patient(patient: dict) -> dict:
94
  patient_copy = patient.copy()
95
  if "_id" in patient_copy:
@@ -101,13 +173,13 @@ async def analyze_patient(patient: dict):
101
  serialized = serialize_patient(patient)
102
  doc = json.dumps(serialized, indent=2)
103
  logger.info(f"🧾 Analyzing patient: {serialized.get('fhir_id')}")
104
- logger.debug(f"🧠 Data passed to TxAgent:\n{doc[:1000]}")
105
 
 
106
  message = (
107
  "You are a clinical decision support AI.\n\n"
108
  "Given the patient document below:\n"
109
  "1. Summarize the patient's medical history.\n"
110
- "2. Identify risks or red flags.\n"
111
  "3. Highlight missed diagnoses or treatments.\n"
112
  "4. Suggest next clinical steps.\n"
113
  f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
@@ -115,19 +187,36 @@ async def analyze_patient(patient: dict):
115
 
116
  raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
117
  structured = structure_medical_response(raw)
118
-
 
 
 
 
 
 
 
 
 
119
  analysis_doc = {
120
  "patient_id": serialized.get("fhir_id"),
121
  "timestamp": datetime.utcnow(),
122
  "summary": structured,
 
123
  "raw": raw
124
  }
 
125
  await analysis_collection.update_one(
126
  {"patient_id": serialized.get("fhir_id")},
127
  {"$set": analysis_doc},
128
  upsert=True
129
  )
 
 
 
 
 
130
  logger.info(f"✅ Stored analysis for patient {serialized.get('fhir_id')}")
 
131
  except Exception as e:
132
  logger.error(f"Error analyzing patient: {e}")
133
 
@@ -139,7 +228,7 @@ async def analyze_all_patients():
139
 
140
  @app.on_event("startup")
141
  async def startup_event():
142
- global agent, patients_collection, analysis_collection
143
 
144
  agent = TxAgent(
145
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
@@ -160,6 +249,7 @@ async def startup_event():
160
  db = get_mongo_client()["cps_db"]
161
  patients_collection = db["patients"]
162
  analysis_collection = db["patient_analysis_results"]
 
163
  logger.info("📡 Connected to MongoDB")
164
 
165
  asyncio.create_task(analyze_all_patients())
@@ -168,7 +258,8 @@ async def startup_event():
168
  async def status():
169
  return {
170
  "status": "running",
171
- "timestamp": datetime.utcnow().isoformat()
 
172
  }
173
 
174
  @app.post("/chat-stream")
@@ -201,4 +292,4 @@ async def chat_stream_endpoint(request: ChatRequest):
201
  logger.error(f"Streaming error: {e}")
202
  yield f"⚠️ Error: {e}"
203
 
204
- return StreamingResponse(token_stream(), media_type="text/plain")
 
4
  import logging
5
  import re
6
  from datetime import datetime
7
+ from typing import List, Dict, Optional, Tuple
8
+ from enum import Enum
9
 
10
  from fastapi import FastAPI, HTTPException
11
  from fastapi.responses import StreamingResponse
 
21
  logger = logging.getLogger("TxAgentAPI")
22
 
23
  # App
24
+ app = FastAPI(title="TxAgent API", version="2.2.0") # Version bump for new features
25
 
26
  app.add_middleware(
27
  CORSMiddleware,
 
37
  history: Optional[List[Dict]] = None
38
  format: Optional[str] = "clean"
39
 
40
+ # Enums
41
+ class RiskLevel(str, Enum):
42
+ NONE = "none"
43
+ LOW = "low"
44
+ MODERATE = "moderate"
45
+ HIGH = "high"
46
+ SEVERE = "severe"
47
+
48
  # Globals
49
  agent = None
50
  patients_collection = None
51
  analysis_collection = None
52
+ alerts_collection = None
53
 
54
  # Helpers
55
  def clean_text_response(text: str) -> str:
 
69
  def structure_medical_response(text: str) -> Dict:
70
  """Improved version that handles both markdown and plain text formats"""
71
  def extract_improved(text: str, heading: str) -> str:
 
72
  patterns = [
73
+ rf"{re.escape(heading)}:\s*\n(.*?)(?=\n\s*\n|\Z)",
74
+ rf"\*\*{re.escape(heading)}\*\*:\s*\n(.*?)(?=\n\s*\n|\Z)",
75
+ rf"{re.escape(heading)}[\s\-]+(.*?)(?=\n\s*\n|\Z)",
76
+ rf"\n{re.escape(heading)}\s*\n(.*?)(?=\n\s*\n|\Z)"
77
  ]
78
 
79
  for pattern in patterns:
80
  match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
81
  if match:
82
  content = match.group(1).strip()
 
83
  content = re.sub(r'^\s*[\-\*]\s*', '', content, flags=re.MULTILINE)
84
  return content
 
85
  return ""
86
 
 
87
  text = text.replace('**', '').replace('__', '')
88
 
89
  return {
 
96
  "recommendations": extract_improved(text, "Suggest Next Clinical Steps") or
97
  extract_improved(text, "Suggested Clinical Actions")
98
  }
99
+
100
+ def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
101
+ """Analyze text for suicide risk factors and return assessment"""
102
+ suicide_keywords = [
103
+ 'suicide', 'suicidal', 'kill myself', 'end my life',
104
+ 'want to die', 'self-harm', 'self harm', 'hopeless',
105
+ 'no reason to live', 'plan to die'
106
+ ]
107
+
108
+ # Check for explicit mentions
109
+ explicit_mentions = [kw for kw in suicide_keywords if kw in text.lower()]
110
+
111
+ if not explicit_mentions:
112
+ return RiskLevel.NONE, 0.0, []
113
+
114
+ # If found, ask AI for detailed assessment
115
+ assessment_prompt = (
116
+ "Assess the suicide risk level based on this text. "
117
+ "Consider frequency, specificity, and severity of statements. "
118
+ "Respond with JSON format: {\"risk_level\": \"low/moderate/high/severe\", "
119
+ "\"risk_score\": 0-1, \"factors\": [\"list of risk factors\"]}\n\n"
120
+ f"Text to assess:\n{text}"
121
+ )
122
+
123
+ try:
124
+ response = agent.chat(
125
+ message=assessment_prompt,
126
+ history=[],
127
+ temperature=0.2, # Lower temp for more deterministic responses
128
+ max_new_tokens=256
129
+ )
130
+
131
+ # Extract JSON from response
132
+ json_match = re.search(r'\{.*\}', response, re.DOTALL)
133
+ if json_match:
134
+ assessment = json.loads(json_match.group())
135
+ return (
136
+ RiskLevel(assessment.get("risk_level", "none").lower()),
137
+ float(assessment.get("risk_score", 0)),
138
+ assessment.get("factors", [])
139
+ )
140
+ except Exception as e:
141
+ logger.error(f"Error in suicide risk assessment: {e}")
142
+
143
+ # Fallback if JSON parsing fails
144
+ risk_score = min(0.1 * len(explicit_mentions), 0.9) # Cap at 0.9 for fallback
145
+ if risk_score > 0.7:
146
+ return RiskLevel.HIGH, risk_score, explicit_mentions
147
+ elif risk_score > 0.4:
148
+ return RiskLevel.MODERATE, risk_score, explicit_mentions
149
+ return RiskLevel.LOW, risk_score, explicit_mentions
150
+
151
+ async def create_alert(patient_id: str, risk_data: dict):
152
+ """Create an alert document in the database"""
153
+ alert_doc = {
154
+ "patient_id": patient_id,
155
+ "type": "suicide_risk",
156
+ "level": risk_data["level"],
157
+ "score": risk_data["score"],
158
+ "factors": risk_data["factors"],
159
+ "timestamp": datetime.utcnow(),
160
+ "acknowledged": False
161
+ }
162
+ await alerts_collection.insert_one(alert_doc)
163
+ logger.warning(f"⚠️ Created suicide risk alert for patient {patient_id}")
164
+
165
  def serialize_patient(patient: dict) -> dict:
166
  patient_copy = patient.copy()
167
  if "_id" in patient_copy:
 
173
  serialized = serialize_patient(patient)
174
  doc = json.dumps(serialized, indent=2)
175
  logger.info(f"🧾 Analyzing patient: {serialized.get('fhir_id')}")
 
176
 
177
+ # Main clinical analysis
178
  message = (
179
  "You are a clinical decision support AI.\n\n"
180
  "Given the patient document below:\n"
181
  "1. Summarize the patient's medical history.\n"
182
+ "2. Identify risks or red flags (including mental health and suicide risk).\n"
183
  "3. Highlight missed diagnoses or treatments.\n"
184
  "4. Suggest next clinical steps.\n"
185
  f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
 
187
 
188
  raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
189
  structured = structure_medical_response(raw)
190
+
191
+ # Suicide risk assessment
192
+ risk_level, risk_score, risk_factors = detect_suicide_risk(raw)
193
+ suicide_risk = {
194
+ "level": risk_level.value,
195
+ "score": risk_score,
196
+ "factors": risk_factors
197
+ }
198
+
199
+ # Store analysis
200
  analysis_doc = {
201
  "patient_id": serialized.get("fhir_id"),
202
  "timestamp": datetime.utcnow(),
203
  "summary": structured,
204
+ "suicide_risk": suicide_risk,
205
  "raw": raw
206
  }
207
+
208
  await analysis_collection.update_one(
209
  {"patient_id": serialized.get("fhir_id")},
210
  {"$set": analysis_doc},
211
  upsert=True
212
  )
213
+
214
+ # Create alert if risk is above threshold
215
+ if risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
216
+ await create_alert(serialized.get("fhir_id"), suicide_risk)
217
+
218
  logger.info(f"✅ Stored analysis for patient {serialized.get('fhir_id')}")
219
+
220
  except Exception as e:
221
  logger.error(f"Error analyzing patient: {e}")
222
 
 
228
 
229
  @app.on_event("startup")
230
  async def startup_event():
231
+ global agent, patients_collection, analysis_collection, alerts_collection
232
 
233
  agent = TxAgent(
234
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
 
249
  db = get_mongo_client()["cps_db"]
250
  patients_collection = db["patients"]
251
  analysis_collection = db["patient_analysis_results"]
252
+ alerts_collection = db["clinical_alerts"] # New collection for alerts
253
  logger.info("📡 Connected to MongoDB")
254
 
255
  asyncio.create_task(analyze_all_patients())
 
258
  async def status():
259
  return {
260
  "status": "running",
261
+ "timestamp": datetime.utcnow().isoformat(),
262
+ "version": "2.2.0"
263
  }
264
 
265
  @app.post("/chat-stream")
 
292
  logger.error(f"Streaming error: {e}")
293
  yield f"⚠️ Error: {e}"
294
 
295
+ return StreamingResponse(token_stream(), media_type="text/plain")