Ali2206 commited on
Commit
f0898a3
·
verified ·
1 Parent(s): 4edd370

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -34
app.py CHANGED
@@ -7,27 +7,19 @@ from datetime import datetime
7
  from typing import List, Dict, Optional
8
 
9
  from fastapi import FastAPI, HTTPException
10
- from fastapi.responses import JSONResponse, StreamingResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel
13
- from pymongo import MongoClient
14
- from bson import ObjectId
15
  import asyncio
16
 
17
- # Adjust sys path
18
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
19
-
20
- # TxAgent
21
  from txagent.txagent import TxAgent
22
-
23
- # MongoDB
24
  from db.mongo import get_mongo_client
25
 
26
- # Setup logging
27
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
28
  logger = logging.getLogger("TxAgentAPI")
29
 
30
- # FastAPI app
31
  app = FastAPI(title="TxAgent API", version="2.1.0")
32
 
33
  app.add_middleware(
@@ -36,7 +28,7 @@ app.add_middleware(
36
  allow_methods=["*"], allow_headers=["*"]
37
  )
38
 
39
- # Models
40
  class ChatRequest(BaseModel):
41
  message: str
42
  temperature: float = 0.7
@@ -46,7 +38,6 @@ class ChatRequest(BaseModel):
46
 
47
  # Globals
48
  agent = None
49
- mongo_client = None
50
  patients_collection = None
51
  analysis_collection = None
52
 
@@ -58,15 +49,13 @@ def clean_text_response(text: str) -> str:
58
 
59
  def extract_section(text: str, heading: str) -> str:
60
  try:
61
- # Accept formats like "Heading:" or "**Heading:**"
62
- pattern = rf"(\*\*)?{re.escape(heading)}(\*\*)?:\s*\n(.*?)(?=\n\s*\*?\*?\w|$)"
63
  match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
64
- return clean_text_response(match.group(3)) if match else ""
65
  except Exception as e:
66
  logger.error(f"Section extraction failed for heading '{heading}': {e}")
67
  return ""
68
 
69
-
70
  def structure_medical_response(text: str) -> Dict:
71
  return {
72
  "summary": extract_section(text, "Summarize the patient's medical history"),
@@ -75,7 +64,6 @@ def structure_medical_response(text: str) -> Dict:
75
  "recommendations": extract_section(text, "Suggested Clinical Actions")
76
  }
77
 
78
-
79
  def serialize_patient(patient: dict) -> dict:
80
  patient_copy = patient.copy()
81
  if "_id" in patient_copy:
@@ -84,11 +72,15 @@ def serialize_patient(patient: dict) -> dict:
84
 
85
  async def analyze_patient(patient: dict):
86
  try:
87
- doc = json.dumps(serialize_patient(patient), indent=2)
 
 
 
 
88
  message = (
89
  "You are a clinical decision support AI.\n\n"
90
  "Given the patient document below:\n"
91
- "1. Summarize their medical history.\n"
92
  "2. Identify risks or red flags.\n"
93
  "3. Highlight missed diagnoses or treatments.\n"
94
  "4. Suggest next clinical steps.\n"
@@ -99,17 +91,17 @@ async def analyze_patient(patient: dict):
99
  structured = structure_medical_response(raw)
100
 
101
  analysis_doc = {
102
- "patient_id": patient.get("fhir_id"),
103
  "timestamp": datetime.utcnow(),
104
  "summary": structured,
105
  "raw": raw
106
  }
107
  await analysis_collection.update_one(
108
- {"patient_id": patient.get("fhir_id")},
109
  {"$set": analysis_doc},
110
  upsert=True
111
  )
112
- logger.info(f"✔️ Analysis stored for patient {patient.get('fhir_id')}")
113
  except Exception as e:
114
  logger.error(f"Error analyzing patient: {e}")
115
 
@@ -119,12 +111,10 @@ async def analyze_all_patients():
119
  await analyze_patient(patient)
120
  await asyncio.sleep(0.1)
121
 
122
- # Startup logic
123
  @app.on_event("startup")
124
  async def startup_event():
125
- global agent, mongo_client, patients_collection, analysis_collection
126
 
127
- # Init agent
128
  agent = TxAgent(
129
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
130
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -141,21 +131,17 @@ async def startup_event():
141
  agent.init_model()
142
  logger.info("✅ TxAgent initialized")
143
 
144
- # MongoDB
145
- mongo_client = get_mongo_client()
146
- db = mongo_client["cps_db"]
147
- patients_collection = db.get_collection("patients")
148
- analysis_collection = db.get_collection("patient_analysis_results")
149
-
150
  logger.info("📡 Connected to MongoDB")
 
151
  asyncio.create_task(analyze_all_patients())
152
 
153
- # Endpoints
154
  @app.get("/status")
155
  async def status():
156
  return {
157
  "status": "running",
158
- "version": "2.1.0",
159
  "timestamp": datetime.utcnow().isoformat()
160
  }
161
 
 
7
  from typing import List, Dict, Optional
8
 
9
  from fastapi import FastAPI, HTTPException
10
+ from fastapi.responses import StreamingResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel
 
 
13
  import asyncio
14
 
 
 
 
 
15
  from txagent.txagent import TxAgent
 
 
16
  from db.mongo import get_mongo_client
17
 
18
+ # Logging
19
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
20
  logger = logging.getLogger("TxAgentAPI")
21
 
22
+ # App
23
  app = FastAPI(title="TxAgent API", version="2.1.0")
24
 
25
  app.add_middleware(
 
28
  allow_methods=["*"], allow_headers=["*"]
29
  )
30
 
31
+ # Pydantic
32
  class ChatRequest(BaseModel):
33
  message: str
34
  temperature: float = 0.7
 
38
 
39
  # Globals
40
  agent = None
 
41
  patients_collection = None
42
  analysis_collection = None
43
 
 
49
 
50
  def extract_section(text: str, heading: str) -> str:
51
  try:
52
+ pattern = rf"{re.escape(heading)}:\s*\n(.*?)(?=\n[A-Z][^\n]*:|\Z)"
 
53
  match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
54
+ return match.group(1).strip() if match else ""
55
  except Exception as e:
56
  logger.error(f"Section extraction failed for heading '{heading}': {e}")
57
  return ""
58
 
 
59
  def structure_medical_response(text: str) -> Dict:
60
  return {
61
  "summary": extract_section(text, "Summarize the patient's medical history"),
 
64
  "recommendations": extract_section(text, "Suggested Clinical Actions")
65
  }
66
 
 
67
  def serialize_patient(patient: dict) -> dict:
68
  patient_copy = patient.copy()
69
  if "_id" in patient_copy:
 
72
 
73
  async def analyze_patient(patient: dict):
74
  try:
75
+ serialized = serialize_patient(patient)
76
+ doc = json.dumps(serialized, indent=2)
77
+ logger.info(f"🧾 Analyzing patient: {serialized.get('fhir_id')}")
78
+ logger.debug(f"🧠 Data passed to TxAgent:\n{doc[:1000]}")
79
+
80
  message = (
81
  "You are a clinical decision support AI.\n\n"
82
  "Given the patient document below:\n"
83
+ "1. Summarize the patient's medical history.\n"
84
  "2. Identify risks or red flags.\n"
85
  "3. Highlight missed diagnoses or treatments.\n"
86
  "4. Suggest next clinical steps.\n"
 
91
  structured = structure_medical_response(raw)
92
 
93
  analysis_doc = {
94
+ "patient_id": serialized.get("fhir_id"),
95
  "timestamp": datetime.utcnow(),
96
  "summary": structured,
97
  "raw": raw
98
  }
99
  await analysis_collection.update_one(
100
+ {"patient_id": serialized.get("fhir_id")},
101
  {"$set": analysis_doc},
102
  upsert=True
103
  )
104
+ logger.info(f" Stored analysis for patient {serialized.get('fhir_id')}")
105
  except Exception as e:
106
  logger.error(f"Error analyzing patient: {e}")
107
 
 
111
  await analyze_patient(patient)
112
  await asyncio.sleep(0.1)
113
 
 
114
  @app.on_event("startup")
115
  async def startup_event():
116
+ global agent, patients_collection, analysis_collection
117
 
 
118
  agent = TxAgent(
119
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
120
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
131
  agent.init_model()
132
  logger.info("✅ TxAgent initialized")
133
 
134
+ db = get_mongo_client()["cps_db"]
135
+ patients_collection = db["patients"]
136
+ analysis_collection = db["patient_analysis_results"]
 
 
 
137
  logger.info("📡 Connected to MongoDB")
138
+
139
  asyncio.create_task(analyze_all_patients())
140
 
 
141
  @app.get("/status")
142
  async def status():
143
  return {
144
  "status": "running",
 
145
  "timestamp": datetime.utcnow().isoformat()
146
  }
147