Ali2206 committed on
Commit f275c80 · verified · 1 Parent(s): 1e0df14

Update app.py

Files changed (1)
  1. app.py +117 -112
app.py CHANGED
@@ -1,3 +1,4 @@
+# app.py (in TxAgent-API)
 import os
 import sys
 import json
@@ -9,9 +10,11 @@ import base64
 from datetime import datetime
 from typing import List, Dict, Optional, Tuple
 from enum import Enum
-from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Form
+from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Form, Depends
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.security import OAuth2PasswordBearer
+from fastapi.encoders import jsonable_encoder
 from pydantic import BaseModel
 import asyncio
 from bson import ObjectId
@@ -20,17 +23,19 @@ from gtts import gTTS
 from pydub import AudioSegment
 import PyPDF2
 import mimetypes
+from docx import Document
+from jose import JWTError, jwt
 from txagent.txagent import TxAgent
 from db.mongo import get_mongo_client
-from fastapi.encoders import jsonable_encoder
-from docx import Document
+
 # Logging
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger("TxAgentAPI")
 
 # App
-app = FastAPI(title="TxAgent API", version="2.6.0")  # Updated version for optional patient_id
+app = FastAPI(title="TxAgent API", version="2.6.0")
 
+# CORS
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -39,6 +44,13 @@ app.add_middleware(
     allow_headers=["*"]
 )
 
+# JWT settings (must match CPS-API)
+SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key")  # Same as CPS-API
+ALGORITHM = "HS256"
+
+# OAuth2 scheme (point to CPS-API's login endpoint)
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="https://rocketfarmstudios-cps-api.hf.space/auth/login")
+
 # Pydantic Models
 class ChatRequest(BaseModel):
     message: str
@@ -55,7 +67,7 @@ class VoiceOutputRequest(BaseModel):
     text: str
     language: str = "en"
     slow: bool = False
-    return_format: str = "mp3"  # mp3 or base64
+    return_format: str = "mp3"
 
 # Enums
 class RiskLevel(str, Enum):
@@ -71,7 +83,26 @@ patients_collection = None
 analysis_collection = None
 alerts_collection = None
 
-# Helpers
+# JWT validation
+async def get_current_user(token: str = Depends(oauth2_scheme)):
+    credentials_exception = HTTPException(
+        status_code=401,
+        detail="Could not validate credentials",
+        headers={"WWW-Authenticate": "Bearer"},
+    )
+    try:
+        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+        email: str = payload.get("sub")
+        if email is None:
+            raise credentials_exception
+    except JWTError:
+        raise credentials_exception
+    user = await users_collection.find_one({"email": email})
+    if user is None:
+        raise credentials_exception
+    return user
+
+# Helper functions (unchanged from your original code)
 def clean_text_response(text: str) -> str:
     text = re.sub(r'\n\s*\n', '\n\n', text)
     text = re.sub(r'[ ]+', ' ', text)
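
Note on the new dependency: get_current_user accepts any HS256 token signed with the shared SECRET_KEY whose "sub" claim is an email that exists in users_collection. A minimal sketch of minting and decoding such a token with python-jose, under the assumption that the CPS-API puts the user's email in "sub" exactly as decoded above (the address shown is hypothetical):

import os
from jose import jwt

SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key")  # must match the CPS-API value
ALGORITHM = "HS256"

# Mint a token the way the login service presumably does: email in the "sub" claim.
token = jwt.encode({"sub": "clinician@example.com"}, SECRET_KEY, algorithm=ALGORITHM)

# Mirrors the decode step inside get_current_user.
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
assert payload["sub"] == "clinician@example.com"

Even a validly signed token is rejected unless the email resolves to a document in cps_db.users, since get_current_user finishes with a users_collection lookup.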
@@ -87,7 +118,6 @@ def extract_section(text: str, heading: str) -> str:
     return ""
 
 def structure_medical_response(text: str) -> Dict:
-    """Improved version that handles both markdown and plain text formats"""
     def extract_improved(text: str, heading: str) -> str:
         patterns = [
             rf"{re.escape(heading)}:\s*\n(.*?)(?=\n\s*\n|\Z)",
@@ -95,7 +125,6 @@ def structure_medical_response(text: str) -> Dict:
             rf"{re.escape(heading)}[\s\-]+(.*?)(?=\n\s*\n|\Z)",
             rf"\n{re.escape(heading)}\s*\n(.*?)(?=\n\s*\n|\Z)"
         ]
-
         for pattern in patterns:
             match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
             if match:
@@ -103,9 +132,8 @@ def structure_medical_response(text: str) -> Dict:
                 content = re.sub(r'^\s*[\-\*]\s*', '', content, flags=re.MULTILINE)
                 return content
         return ""
-
-    text = text.replace('**', '').replace('__', '')
 
+    text = text.replace('**', '').replace('__', '')
     return {
         "summary": extract_improved(text, "Summary of Patient's Medical History") or
                    extract_improved(text, "Summarize the patient's medical history"),
@@ -118,15 +146,12 @@ def structure_medical_response(text: str) -> Dict:
     }
 
 def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
-    """Analyze text for suicide risk factors and return assessment"""
     suicide_keywords = [
        'suicide', 'suicidal', 'kill myself', 'end my life',
        'want to die', 'self-harm', 'self harm', 'hopeless',
        'no reason to live', 'plan to die'
    ]
-
    explicit_mentions = [kw for kw in suicide_keywords if kw in text.lower()]
-
    if not explicit_mentions:
        return RiskLevel.NONE, 0.0, []
 
@@ -145,7 +170,6 @@ def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
        temperature=0.2,
        max_new_tokens=256
    )
-
    json_match = re.search(r'\{.*\}', response, re.DOTALL)
    if json_match:
        assessment = json.loads(json_match.group())
@@ -165,7 +189,6 @@ def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
    return RiskLevel.LOW, risk_score, explicit_mentions
 
 async def create_alert(patient_id: str, risk_data: dict):
-    """Create an alert document in the database"""
    alert_doc = {
        "patient_id": patient_id,
        "type": "suicide_risk",
@@ -185,16 +208,13 @@ def serialize_patient(patient: dict) -> dict:
     return patient_copy
 
 def compute_patient_data_hash(data: dict) -> str:
-    """Compute SHA-256 hash of patient data or report."""
     serialized = json.dumps(data, sort_keys=True)
     return hashlib.sha256(serialized.encode()).hexdigest()
 
 def compute_file_content_hash(file_content: bytes) -> str:
-    """Compute SHA-256 hash of file content."""
     return hashlib.sha256(file_content).hexdigest()
 
 def extract_text_from_pdf(pdf_data: bytes) -> str:
-    """Extract text from a PDF file."""
     try:
         pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
         text = ""
@@ -206,85 +226,70 @@ def extract_text_from_pdf(pdf_data: bytes) -> str:
         raise HTTPException(status_code=400, detail="Failed to extract text from PDF")
 
 async def analyze_patient_report(patient_id: Optional[str], report_content: str, file_type: str, file_content: bytes):
-    """Analyze a patient report and store results."""
-    try:
-        # Use file content hash as identifier if no patient_id is provided
-        identifier = patient_id if patient_id else compute_file_content_hash(file_content)
-        report_data = {"identifier": identifier, "content": report_content, "file_type": file_type}
-        report_hash = compute_patient_data_hash(report_data)
-        logger.info(f"🧾 Analyzing report for identifier: {identifier}")
-
-        # Check if analysis exists and hash matches
-        existing_analysis = await analysis_collection.find_one({"identifier": identifier, "report_hash": report_hash})
-        if existing_analysis:
-            logger.info(f" No changes in report data for {identifier}, skipping analysis")
-            return existing_analysis
-
-        # Construct analysis prompt
-        prompt = (
-            "You are a clinical decision support AI. Analyze the following patient report:\n"
-            "1. Summarize the patient's medical history.\n"
-            "2. Identify risks or red flags (including mental health and suicide risk).\n"
-            "3. Highlight missed diagnoses or treatments.\n"
-            "4. Suggest next clinical steps.\n"
-            f"\nPatient Report ({file_type}):\n{'-'*40}\n{report_content[:10000]}"
-        )
-
-        # Perform analysis
-        raw_response = agent.chat(
-            message=prompt,
-            history=[],
-            temperature=0.7,
-            max_new_tokens=1024
-        )
-        structured_response = structure_medical_response(raw_response)
+    identifier = patient_id if patient_id else compute_file_content_hash(file_content)
+    report_data = {"identifier": identifier, "content": report_content, "file_type": file_type}
+    report_hash = compute_patient_data_hash(report_data)
+    logger.info(f"🧾 Analyzing report for identifier: {identifier}")
+
+    existing_analysis = await analysis_collection.find_one({"identifier": identifier, "report_hash": report_hash})
+    if existing_analysis:
+        logger.info(f"✅ No changes in report data for {identifier}, skipping analysis")
+        return existing_analysis
+
+    prompt = (
+        "You are a clinical decision support AI. Analyze the following patient report:\n"
+        "1. Summarize the patient's medical history.\n"
+        "2. Identify risks or red flags (including mental health and suicide risk).\n"
+        "3. Highlight missed diagnoses or treatments.\n"
+        "4. Suggest next clinical steps.\n"
+        f"\nPatient Report ({file_type}):\n{'-'*40}\n{report_content[:10000]}"
+    )
 
-        # Suicide risk assessment
-        risk_level, risk_score, risk_factors = detect_suicide_risk(raw_response)
-        suicide_risk = {
-            "level": risk_level.value,
-            "score": risk_score,
-            "factors": risk_factors
-        }
+    raw_response = agent.chat(
+        message=prompt,
+        history=[],
+        temperature=0.7,
+        max_new_tokens=1024
+    )
+    structured_response = structure_medical_response(raw_response)
 
-        # Store analysis
-        analysis_doc = {
-            "identifier": identifier,
-            "patient_id": patient_id,  # May be None
-            "timestamp": datetime.utcnow(),
-            "summary": structured_response,
-            "suicide_risk": suicide_risk,
-            "raw": raw_response,
-            "report_hash": report_hash,
-            "file_type": file_type
-        }
+    risk_level, risk_score, risk_factors = detect_suicide_risk(raw_response)
+    suicide_risk = {
+        "level": risk_level.value,
+        "score": risk_score,
+        "factors": risk_factors
+    }
 
-        await analysis_collection.update_one(
-            {"identifier": identifier, "report_hash": report_hash},
-            {"$set": analysis_doc},
-            upsert=True
-        )
+    analysis_doc = {
+        "identifier": identifier,
+        "patient_id": patient_id,
+        "timestamp": datetime.utcnow(),
+        "summary": structured_response,
+        "suicide_risk": suicide_risk,
+        "raw": raw_response,
+        "report_hash": report_hash,
+        "file_type": file_type
+    }
 
-        # Create alert for high-risk cases only if patient_id is provided
-        if patient_id and risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
-            await create_alert(patient_id, suicide_risk)
+    await analysis_collection.update_one(
+        {"identifier": identifier, "report_hash": report_hash},
+        {"$set": analysis_doc},
+        upsert=True
+    )
 
-        logger.info(f"✅ Stored analysis for identifier {identifier}")
-        return analysis_doc
+    if patient_id and risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
+        await create_alert(patient_id, suicide_risk)
 
-    except Exception as e:
-        logger.error(f"Error analyzing patient report: {e}")
-        raise HTTPException(status_code=500, detail="Failed to analyze patient report")
+    logger.info(f"✅ Stored analysis for identifier {identifier}")
+    return analysis_doc
 
 async def analyze_all_patients():
-    """Analyze all patients in the database."""
     patients = await patients_collection.find({}).to_list(length=None)
     for patient in patients:
         await analyze_patient(patient)
         await asyncio.sleep(0.1)
 
 async def analyze_patient(patient: dict):
-    """Analyze patient data (existing logic for patient records)."""
     try:
         serialized = serialize_patient(patient)
         patient_id = serialized.get("fhir_id")
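
The skip-if-unchanged branch above relies on compute_patient_data_hash being deterministic over the (identifier, content, file_type) triple. A standalone illustration of that property, mirroring the helper defined earlier in the file (the sample values are made up):

import hashlib
import json

def compute_patient_data_hash(data: dict) -> str:
    # Same construction as in app.py: canonical JSON, then SHA-256.
    return hashlib.sha256(json.dumps(data, sort_keys=True).encode()).hexdigest()

a = {"identifier": "abc123", "content": "Patient reports chest pain.", "file_type": "application/pdf"}
b = {"file_type": "application/pdf", "content": "Patient reports chest pain.", "identifier": "abc123"}

# sort_keys=True makes key order irrelevant, so re-uploading an unchanged report
# yields the same report_hash and analyze_patient_report returns the cached analysis.
assert compute_patient_data_hash(a) == compute_patient_data_hash(b)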
@@ -342,9 +347,7 @@ async def analyze_patient(patient: dict):
         logger.error(f"Error analyzing patient: {e}")
 
 def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
-    """Convert speech to text using Google's speech recognition."""
     recognizer = sr.Recognizer()
-
     try:
         with io.BytesIO(audio_data) as audio_file:
             with sr.AudioFile(audio_file) as source:
@@ -362,7 +365,6 @@ def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
         raise HTTPException(status_code=500, detail="Error processing speech")
 
 def text_to_speech(text: str, language: str = "en", slow: bool = False) -> bytes:
-    """Convert text to speech using gTTS and return as MP3 bytes."""
     try:
         tts = gTTS(text=text, lang=language, slow=slow)
         mp3_fp = io.BytesIO()
@@ -394,6 +396,8 @@ async def startup_event():
     logger.info("✅ TxAgent initialized")
 
     db = get_mongo_client()["cps_db"]
+    global users_collection  # Add this to access users_collection for authentication
+    users_collection = db["users"]
     patients_collection = db["patients"]
     analysis_collection = db["patient_analysis_results"]
     alerts_collection = db["clinical_alerts"]
@@ -401,8 +405,10 @@ async def startup_event():
 
     asyncio.create_task(analyze_all_patients())
 
+# Protected Endpoints (add Depends(get_current_user) to all endpoints)
 @app.get("/status")
-async def status():
+async def status(current_user: dict = Depends(get_current_user)):
+    logger.info(f"Status endpoint accessed by {current_user['email']}")
     return {
         "status": "running",
         "timestamp": datetime.utcnow().isoformat(),
@@ -411,7 +417,11 @@ async def status():
     }
 
 @app.get("/patients/analysis-results")
-async def get_patient_analysis_results(name: Optional[str] = Query(None)):
+async def get_patient_analysis_results(
+    name: Optional[str] = Query(None),
+    current_user: dict = Depends(get_current_user)
+):
+    logger.info(f"Fetching analysis results by {current_user['email']}")
     try:
         query = {}
         if name:
@@ -438,7 +448,11 @@ async def get_patient_analysis_results(name: Optional[str] = Query(None)):
         raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")
 
 @app.post("/chat-stream")
-async def chat_stream_endpoint(request: ChatRequest):
+async def chat_stream_endpoint(
+    request: ChatRequest,
+    current_user: dict = Depends(get_current_user)
+):
+    logger.info(f"Chat stream initiated by {current_user['email']}")
     async def token_stream():
         try:
             conversation = [{"role": "system", "content": agent.chat_prompt}]
@@ -472,9 +486,10 @@
 @app.post("/voice/transcribe")
 async def transcribe_voice(
     audio: UploadFile = File(...),
-    language: str = Query("en-US", description="Language code for speech recognition")
+    language: str = Query("en-US", description="Language code for speech recognition"),
+    current_user: dict = Depends(get_current_user)
 ):
-    """Convert speech to text."""
+    logger.info(f"Voice transcription initiated by {current_user['email']}")
     try:
         audio_data = await audio.read()
         if not audio.filename.lower().endswith(('.wav', '.mp3', '.ogg', '.flac')):
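
/voice/transcribe keeps its multipart upload and language query parameter; only the Bearer header is new. A minimal sketch reusing a token obtained as in the earlier login example (file name and base URL are hypothetical):

import requests

with open("note.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/voice/transcribe",
        params={"language": "en-US"},
        files={"audio": ("note.wav", f, "audio/wav")},
        headers={"Authorization": f"Bearer {token}"},
    )
print(resp.json())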
@@ -490,8 +505,11 @@ async def transcribe_voice(
         raise HTTPException(status_code=500, detail="Error processing voice input")
 
 @app.post("/voice/synthesize")
-async def synthesize_voice(request: VoiceOutputRequest):
-    """Convert text to speech."""
+async def synthesize_voice(
+    request: VoiceOutputRequest,
+    current_user: dict = Depends(get_current_user)
+):
+    logger.info(f"Voice synthesis initiated by {current_user['email']}")
     try:
         audio_data = text_to_speech(request.text, request.language, request.slow)
 
@@ -515,9 +533,10 @@ async def voice_chat_endpoint(
     audio: UploadFile = File(...),
     language: str = Query("en-US", description="Language code for speech recognition"),
     temperature: float = Query(0.7, ge=0.1, le=1.0),
-    max_new_tokens: int = Query(512, ge=50, le=1024)
+    max_new_tokens: int = Query(512, ge=50, le=1024),
+    current_user: dict = Depends(get_current_user)
 ):
-    """Complete voice chat interaction (speech-to-text -> AI -> text-to-speech)."""
+    logger.info(f"Voice chat initiated by {current_user['email']}")
     try:
         audio_data = await audio.read()
         user_message = recognize_speech(audio_data, language)
@@ -548,18 +567,11 @@ async def analyze_clinical_report(
     file: UploadFile = File(...),
     patient_id: Optional[str] = Form(None),
     temperature: float = Form(0.5),
-    max_new_tokens: int = Form(1024)
+    max_new_tokens: int = Form(1024),
+    current_user: dict = Depends(get_current_user)
 ):
-    """
-    Analyze a clinical patient report from an uploaded file.
-    Parameters:
-    - file: Uploaded clinical report file (PDF, TXT, DOCX)
-    - patient_id: Optional patient ID to associate with this report
-    - temperature: Controls randomness of response (0.1-1.0)
-    - max_new_tokens: Maximum length of response
-    """
+    logger.info(f"Report analysis initiated by {current_user['email']}")
     try:
-        # Validate file type
         content_type = file.content_type
         allowed_types = [
             'application/pdf',
@@ -573,10 +585,8 @@ async def analyze_clinical_report(
                 detail=f"Unsupported file type: {content_type}. Supported types: PDF, TXT, DOCX"
             )
 
-        # Read file content
         file_content = await file.read()
 
-        # Extract text from file
         if content_type == 'application/pdf':
             text = extract_text_from_pdf(file_content)
         elif content_type == 'text/plain':
@@ -587,7 +597,6 @@ async def analyze_clinical_report(
         else:
             raise HTTPException(status_code=400, detail="Unsupported file type")
 
-        # Clean and validate text
         text = clean_text_response(text)
         if len(text.strip()) < 50:
             raise HTTPException(
@@ -595,7 +604,6 @@ async def analyze_clinical_report(
                 detail="Extracted text is too short (minimum 50 characters required)"
             )
 
-        # Analyze the report
         analysis = await analyze_patient_report(
             patient_id=patient_id,
             report_content=text,
@@ -603,13 +611,11 @@ async def analyze_clinical_report(
             file_content=file_content
         )
 
-        # Manually convert ObjectId and timestamp if needed
         if "_id" in analysis and isinstance(analysis["_id"], ObjectId):
             analysis["_id"] = str(analysis["_id"])
         if "timestamp" in analysis and isinstance(analysis["timestamp"], datetime):
             analysis["timestamp"] = analysis["timestamp"].isoformat()
 
-        # Return response using jsonable_encoder
         return JSONResponse(content=jsonable_encoder({
             "status": "success",
             "analysis": analysis,
@@ -627,7 +633,6 @@ async def analyze_clinical_report(
             detail=f"Failed to analyze report: {str(e)}"
         )
 
-
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
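
/chat-stream is also protected now and returns a streamed response, so a client should read the body incrementally. A sketch assuming the same Bearer token and that ChatRequest needs only message (any other fields of the model are outside this diff):

import requests

resp = requests.post(
    "http://localhost:8000/chat-stream",
    json={"message": "Summarize the latest report for this patient."},
    headers={"Authorization": f"Bearer {token}"},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)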
 