Ali2206 commited on
Commit
62d835c
·
verified ·
1 Parent(s): ac628ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -622
app.py CHANGED
@@ -1,41 +1,13 @@
1
- # app.py (in TxAgent-API)
2
- import os
3
- import sys
4
- import json
5
- import logging
6
- import re
7
- import hashlib
8
- import io
9
- import base64
10
- from datetime import datetime
11
- from typing import List, Dict, Optional, Tuple
12
- from enum import Enum
13
- from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Form, Depends
14
- from fastapi.responses import StreamingResponse, JSONResponse
15
  from fastapi.middleware.cors import CORSMiddleware
16
- from fastapi.security import OAuth2PasswordBearer
17
- from fastapi.encoders import jsonable_encoder
18
- from pydantic import BaseModel
19
- import asyncio
20
- from bson import ObjectId
21
- import speech_recognition as sr
22
- from gtts import gTTS
23
- from pydub import AudioSegment
24
- import PyPDF2
25
- import mimetypes
26
- from docx import Document
27
- from jose import JWTError, jwt
28
- from txagent.txagent import TxAgent
29
- from db.mongo import get_mongo_client
30
 
31
- # Logging
32
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
33
- logger = logging.getLogger("TxAgentAPI")
34
-
35
- # App
36
  app = FastAPI(title="TxAgent API", version="2.6.0")
37
 
38
- # CORS
39
  app.add_middleware(
40
  CORSMiddleware,
41
  allow_origins=["*"],
@@ -44,595 +16,11 @@ app.add_middleware(
44
  allow_headers=["*"]
45
  )
46
 
47
- # JWT settings (must match CPS-API)
48
- SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key") # Same as CPS-API
49
- ALGORITHM = "HS256"
50
-
51
- # OAuth2 scheme (point to CPS-API's login endpoint)
52
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="https://rocketfarmstudios-cps-api.hf.space/auth/login")
53
-
54
- # Pydantic Models
55
- class ChatRequest(BaseModel):
56
- message: str
57
- temperature: float = 0.7
58
- max_new_tokens: int = 512
59
- history: Optional[List[Dict]] = None
60
- format: Optional[str] = "clean"
61
-
62
- class VoiceInputRequest(BaseModel):
63
- audio_format: str = "wav"
64
- language: str = "en-US"
65
-
66
- class VoiceOutputRequest(BaseModel):
67
- text: str
68
- language: str = "en"
69
- slow: bool = False
70
- return_format: str = "mp3"
71
-
72
- # Enums
73
- class RiskLevel(str, Enum):
74
- NONE = "none"
75
- LOW = "low"
76
- MODERATE = "moderate"
77
- HIGH = "high"
78
- SEVERE = "severe"
79
-
80
- # Globals
81
- agent = None
82
- patients_collection = None
83
- analysis_collection = None
84
- alerts_collection = None
85
-
86
- # JWT validation
87
- async def get_current_user(token: str = Depends(oauth2_scheme)):
88
- credentials_exception = HTTPException(
89
- status_code=401,
90
- detail="Could not validate credentials",
91
- headers={"WWW-Authenticate": "Bearer"},
92
- )
93
- try:
94
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
95
- email: str = payload.get("sub")
96
- if email is None:
97
- raise credentials_exception
98
- except JWTError:
99
- raise credentials_exception
100
- user = await users_collection.find_one({"email": email})
101
- if user is None:
102
- raise credentials_exception
103
- return user
104
-
105
- # Helper functions (unchanged from your original code)
106
- def clean_text_response(text: str) -> str:
107
- text = re.sub(r'\n\s*\n', '\n\n', text)
108
- text = re.sub(r'[ ]+', ' ', text)
109
- return text.replace("**", "").replace("__", "").strip()
110
-
111
- def extract_section(text: str, heading: str) -> str:
112
- try:
113
- pattern = rf"{re.escape(heading)}:\s*\n(.*?)(?=\n[A-Z][^\n]*:|\Z)"
114
- match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
115
- return match.group(1).strip() if match else ""
116
- except Exception as e:
117
- logger.error(f"Section extraction failed for heading '{heading}': {e}")
118
- return ""
119
-
120
- def structure_medical_response(text: str) -> Dict:
121
- def extract_improved(text: str, heading: str) -> str:
122
- patterns = [
123
- rf"{re.escape(heading)}:\s*\n(.*?)(?=\n\s*\n|\Z)",
124
- rf"\*\*{re.escape(heading)}\*\*:\s*\n(.*?)(?=\n\s*\n|\Z)",
125
- rf"{re.escape(heading)}[\s\-]+(.*?)(?=\n\s*\n|\Z)",
126
- rf"\n{re.escape(heading)}\s*\n(.*?)(?=\n\s*\n|\Z)"
127
- ]
128
- for pattern in patterns:
129
- match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
130
- if match:
131
- content = match.group(1).strip()
132
- content = re.sub(r'^\s*[\-\*]\s*', '', content, flags=re.MULTILINE)
133
- return content
134
- return ""
135
-
136
- text = text.replace('**', '').replace('__', '')
137
- return {
138
- "summary": extract_improved(text, "Summary of Patient's Medical History") or
139
- extract_improved(text, "Summarize the patient's medical history"),
140
- "risks": extract_improved(text, "Identify Risks or Red Flags") or
141
- extract_improved(text, "Risks or Red Flags"),
142
- "missed_issues": extract_improved(text, "Missed Diagnoses or Treatments") or
143
- extract_improved(text, "What the doctor might have missed"),
144
- "recommendations": extract_improved(text, "Suggest Next Clinical Steps") or
145
- extract_improved(text, "Suggested Clinical Actions")
146
- }
147
-
148
- def detect_suicide_risk(text: str) -> Tuple[RiskLevel, float, List[str]]:
149
- suicide_keywords = [
150
- 'suicide', 'suicidal', 'kill myself', 'end my life',
151
- 'want to die', 'self-harm', 'self harm', 'hopeless',
152
- 'no reason to live', 'plan to die'
153
- ]
154
- explicit_mentions = [kw for kw in suicide_keywords if kw in text.lower()]
155
- if not explicit_mentions:
156
- return RiskLevel.NONE, 0.0, []
157
-
158
- assessment_prompt = (
159
- "Assess the suicide risk level based on this text. "
160
- "Consider frequency, specificity, and severity of statements. "
161
- "Respond with JSON format: {\"risk_level\": \"low/moderate/high/severe\", "
162
- "\"risk_score\": 0-1, \"factors\": [\"list of risk factors\"]}\n\n"
163
- f"Text to assess:\n{text}"
164
- )
165
-
166
- try:
167
- response = agent.chat(
168
- message=assessment_prompt,
169
- history=[],
170
- temperature=0.2,
171
- max_new_tokens=256
172
- )
173
- json_match = re.search(r'\{.*\}', response, re.DOTALL)
174
- if json_match:
175
- assessment = json.loads(json_match.group())
176
- return (
177
- RiskLevel(assessment.get("risk_level", "none").lower()),
178
- float(assessment.get("risk_score", 0)),
179
- assessment.get("factors", [])
180
- )
181
- except Exception as e:
182
- logger.error(f"Error in suicide risk assessment: {e}")
183
-
184
- risk_score = min(0.1 * len(explicit_mentions), 0.9)
185
- if risk_score > 0.7:
186
- return RiskLevel.HIGH, risk_score, explicit_mentions
187
- elif risk_score > 0.4:
188
- return RiskLevel.MODERATE, risk_score, explicit_mentions
189
- return RiskLevel.LOW, risk_score, explicit_mentions
190
-
191
- async def create_alert(patient_id: str, risk_data: dict):
192
- alert_doc = {
193
- "patient_id": patient_id,
194
- "type": "suicide_risk",
195
- "level": risk_data["level"],
196
- "score": risk_data["score"],
197
- "factors": risk_data["factors"],
198
- "timestamp": datetime.utcnow(),
199
- "acknowledged": False
200
- }
201
- await alerts_collection.insert_one(alert_doc)
202
- logger.warning(f"⚠️ Created suicide risk alert for patient {patient_id}")
203
-
204
- def serialize_patient(patient: dict) -> dict:
205
- patient_copy = patient.copy()
206
- if "_id" in patient_copy:
207
- patient_copy["_id"] = str(patient_copy["_id"])
208
- return patient_copy
209
-
210
- def compute_patient_data_hash(data: dict) -> str:
211
- serialized = json.dumps(data, sort_keys=True)
212
- return hashlib.sha256(serialized.encode()).hexdigest()
213
-
214
- def compute_file_content_hash(file_content: bytes) -> str:
215
- return hashlib.sha256(file_content).hexdigest()
216
-
217
- def extract_text_from_pdf(pdf_data: bytes) -> str:
218
- try:
219
- pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
220
- text = ""
221
- for page in pdf_reader.pages:
222
- text += page.extract_text() or ""
223
- return clean_text_response(text)
224
- except Exception as e:
225
- logger.error(f"Error extracting text from PDF: {e}")
226
- raise HTTPException(status_code=400, detail="Failed to extract text from PDF")
227
-
228
- async def analyze_patient_report(patient_id: Optional[str], report_content: str, file_type: str, file_content: bytes):
229
- identifier = patient_id if patient_id else compute_file_content_hash(file_content)
230
- report_data = {"identifier": identifier, "content": report_content, "file_type": file_type}
231
- report_hash = compute_patient_data_hash(report_data)
232
- logger.info(f"🧾 Analyzing report for identifier: {identifier}")
233
-
234
- existing_analysis = await analysis_collection.find_one({"identifier": identifier, "report_hash": report_hash})
235
- if existing_analysis:
236
- logger.info(f"✅ No changes in report data for {identifier}, skipping analysis")
237
- return existing_analysis
238
-
239
- prompt = (
240
- "You are a clinical decision support AI. Analyze the following patient report:\n"
241
- "1. Summarize the patient's medical history.\n"
242
- "2. Identify risks or red flags (including mental health and suicide risk).\n"
243
- "3. Highlight missed diagnoses or treatments.\n"
244
- "4. Suggest next clinical steps.\n"
245
- f"\nPatient Report ({file_type}):\n{'-'*40}\n{report_content[:10000]}"
246
- )
247
-
248
- raw_response = agent.chat(
249
- message=prompt,
250
- history=[],
251
- temperature=0.7,
252
- max_new_tokens=1024
253
- )
254
- structured_response = structure_medical_response(raw_response)
255
-
256
- risk_level, risk_score, risk_factors = detect_suicide_risk(raw_response)
257
- suicide_risk = {
258
- "level": risk_level.value,
259
- "score": risk_score,
260
- "factors": risk_factors
261
- }
262
-
263
- analysis_doc = {
264
- "identifier": identifier,
265
- "patient_id": patient_id,
266
- "timestamp": datetime.utcnow(),
267
- "summary": structured_response,
268
- "suicide_risk": suicide_risk,
269
- "raw": raw_response,
270
- "report_hash": report_hash,
271
- "file_type": file_type
272
- }
273
-
274
- await analysis_collection.update_one(
275
- {"identifier": identifier, "report_hash": report_hash},
276
- {"$set": analysis_doc},
277
- upsert=True
278
- )
279
-
280
- if patient_id and risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
281
- await create_alert(patient_id, suicide_risk)
282
-
283
- logger.info(f"✅ Stored analysis for identifier {identifier}")
284
- return analysis_doc
285
-
286
- async def analyze_all_patients():
287
- patients = await patients_collection.find({}).to_list(length=None)
288
- for patient in patients:
289
- await analyze_patient(patient)
290
- await asyncio.sleep(0.1)
291
-
292
- async def analyze_patient(patient: dict):
293
- try:
294
- serialized = serialize_patient(patient)
295
- patient_id = serialized.get("fhir_id")
296
- patient_hash = compute_patient_data_hash(serialized)
297
- logger.info(f"🧾 Analyzing patient: {patient_id}")
298
-
299
- existing_analysis = await analysis_collection.find_one({"patient_id": patient_id})
300
- if existing_analysis and existing_analysis.get("data_hash") == patient_hash:
301
- logger.info(f"✅ No changes in patient data for {patient_id}, skipping analysis")
302
- return
303
-
304
- doc = json.dumps(serialized, indent=2)
305
- message = (
306
- "You are a clinical decision support AI.\n\n"
307
- "Given the patient document below:\n"
308
- "1. Summarize the patient's medical history.\n"
309
- "2. Identify risks or red flags (including mental health and suicide risk).\n"
310
- "3. Highlight missed diagnoses or treatments.\n"
311
- "4. Suggest next clinical steps.\n"
312
- f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
313
- )
314
-
315
- raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
316
- structured = structure_medical_response(raw)
317
-
318
- risk_level, risk_score, risk_factors = detect_suicide_risk(raw)
319
- suicide_risk = {
320
- "level": risk_level.value,
321
- "score": risk_score,
322
- "factors": risk_factors
323
- }
324
-
325
- analysis_doc = {
326
- "identifier": patient_id,
327
- "patient_id": patient_id,
328
- "timestamp": datetime.utcnow(),
329
- "summary": structured,
330
- "suicide_risk": suicide_risk,
331
- "raw": raw,
332
- "data_hash": patient_hash
333
- }
334
-
335
- await analysis_collection.update_one(
336
- {"identifier": patient_id},
337
- {"$set": analysis_doc},
338
- upsert=True
339
- )
340
-
341
- if risk_level in [RiskLevel.MODERATE, RiskLevel.HIGH, RiskLevel.SEVERE]:
342
- await create_alert(patient_id, suicide_risk)
343
-
344
- logger.info(f"✅ Stored analysis for patient {patient_id}")
345
-
346
- except Exception as e:
347
- logger.error(f"Error analyzing patient: {e}")
348
-
349
- def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
350
- recognizer = sr.Recognizer()
351
- try:
352
- with io.BytesIO(audio_data) as audio_file:
353
- with sr.AudioFile(audio_file) as source:
354
- audio = recognizer.record(source)
355
- text = recognizer.recognize_google(audio, language=language)
356
- return text
357
- except sr.UnknownValueError:
358
- logger.error("Google Speech Recognition could not understand audio")
359
- raise HTTPException(status_code=400, detail="Could not understand audio")
360
- except sr.RequestError as e:
361
- logger.error(f"Could not request results from Google Speech Recognition service; {e}")
362
- raise HTTPException(status_code=503, detail="Speech recognition service unavailable")
363
- except Exception as e:
364
- logger.error(f"Error in speech recognition: {e}")
365
- raise HTTPException(status_code=500, detail="Error processing speech")
366
-
367
- def text_to_speech(text: str, language: str = "en", slow: bool = False) -> bytes:
368
- try:
369
- tts = gTTS(text=text, lang=language, slow=slow)
370
- mp3_fp = io.BytesIO()
371
- tts.write_to_fp(mp3_fp)
372
- mp3_fp.seek(0)
373
- return mp3_fp.read()
374
- except Exception as e:
375
- logger.error(f"Error in text-to-speech conversion: {e}")
376
- raise HTTPException(status_code=500, detail="Error generating speech")
377
-
378
- @app.on_event("startup")
379
- async def startup_event():
380
- global agent, patients_collection, analysis_collection, alerts_collection
381
-
382
- agent = TxAgent(
383
- model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
384
- rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
385
- enable_finish=True,
386
- enable_rag=False,
387
- force_finish=True,
388
- enable_checker=True,
389
- step_rag_num=4,
390
- seed=42
391
- )
392
- agent.chat_prompt = (
393
- "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
394
- )
395
- agent.init_model()
396
- logger.info("✅ TxAgent initialized")
397
-
398
- db = get_mongo_client()["cps_db"]
399
- global users_collection # Add this to access users_collection for authentication
400
- users_collection = db["users"]
401
- patients_collection = db["patients"]
402
- analysis_collection = db["patient_analysis_results"]
403
- alerts_collection = db["clinical_alerts"]
404
- logger.info("📡 Connected to MongoDB")
405
-
406
- asyncio.create_task(analyze_all_patients())
407
-
408
- # Protected Endpoints (add Depends(get_current_user) to all endpoints)
409
- @app.get("/status")
410
- async def status(current_user: dict = Depends(get_current_user)):
411
- logger.info(f"Status endpoint accessed by {current_user['email']}")
412
- return {
413
- "status": "running",
414
- "timestamp": datetime.utcnow().isoformat(),
415
- "version": "2.6.0",
416
- "features": ["chat", "voice-input", "voice-output", "patient-analysis", "report-upload"]
417
- }
418
-
419
- @app.get("/patients/analysis-results")
420
- async def get_patient_analysis_results(
421
- name: Optional[str] = Query(None),
422
- current_user: dict = Depends(get_current_user)
423
- ):
424
- logger.info(f"Fetching analysis results by {current_user['email']}")
425
- try:
426
- query = {}
427
- if name:
428
- name_regex = re.compile(name, re.IGNORECASE)
429
- matching_patients = await patients_collection.find({"full_name": name_regex}).to_list(length=None)
430
- patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
431
- if not patient_ids:
432
- return []
433
- query = {"patient_id": {"$in": patient_ids}}
434
-
435
- analyses = await analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)
436
- enriched_results = []
437
- for analysis in analyses:
438
- patient = await patients_collection.find_one({"fhir_id": analysis.get("patient_id")})
439
- if patient:
440
- analysis["full_name"] = patient.get("full_name", "Unknown")
441
- analysis["_id"] = str(analysis["_id"])
442
- enriched_results.append(analysis)
443
-
444
- return enriched_results
445
-
446
- except Exception as e:
447
- logger.error(f"Error fetching analysis results: {e}")
448
- raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")
449
-
450
- @app.post("/chat-stream")
451
- async def chat_stream_endpoint(
452
- request: ChatRequest,
453
- current_user: dict = Depends(get_current_user)
454
- ):
455
- logger.info(f"Chat stream initiated by {current_user['email']}")
456
- async def token_stream():
457
- try:
458
- conversation = [{"role": "system", "content": agent.chat_prompt}]
459
- if request.history:
460
- conversation.extend(request.history)
461
- conversation.append({"role": "user", "content": request.message})
462
-
463
- input_ids = agent.tokenizer.apply_chat_template(
464
- conversation, add_generation_prompt=True, return_tensors="pt"
465
- ).to(agent.device)
466
-
467
- output = agent.model.generate(
468
- input_ids,
469
- do_sample=True,
470
- temperature=request.temperature,
471
- max_new_tokens=request.max_new_tokens,
472
- pad_token_id=agent.tokenizer.eos_token_id,
473
- return_dict_in_generate=True
474
- )
475
-
476
- text = agent.tokenizer.decode(output["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True)
477
- for chunk in text.split():
478
- yield chunk + " "
479
- await asyncio.sleep(0.05)
480
- except Exception as e:
481
- logger.error(f"Streaming error: {e}")
482
- yield f"⚠️ Error: {e}"
483
-
484
- return StreamingResponse(token_stream(), media_type="text/plain")
485
-
486
- @app.post("/voice/transcribe")
487
- async def transcribe_voice(
488
- audio: UploadFile = File(...),
489
- language: str = Query("en-US", description="Language code for speech recognition"),
490
- current_user: dict = Depends(get_current_user)
491
- ):
492
- logger.info(f"Voice transcription initiated by {current_user['email']}")
493
- try:
494
- audio_data = await audio.read()
495
- if not audio.filename.lower().endswith(('.wav', '.mp3', '.ogg', '.flac')):
496
- raise HTTPException(status_code=400, detail="Unsupported audio format")
497
-
498
- text = recognize_speech(audio_data, language)
499
- return {"text": text}
500
-
501
- except HTTPException:
502
- raise
503
- except Exception as e:
504
- logger.error(f"Error in voice transcription: {e}")
505
- raise HTTPException(status_code=500, detail="Error processing voice input")
506
-
507
- @app.post("/voice/synthesize")
508
- async def synthesize_voice(
509
- request: VoiceOutputRequest,
510
- current_user: dict = Depends(get_current_user)
511
- ):
512
- logger.info(f"Voice synthesis initiated by {current_user['email']}")
513
- try:
514
- audio_data = text_to_speech(request.text, request.language, request.slow)
515
-
516
- if request.return_format == "base64":
517
- return {"audio": base64.b64encode(audio_data).decode('utf-8')}
518
- else:
519
- return StreamingResponse(
520
- io.BytesIO(audio_data),
521
- media_type="audio/mpeg",
522
- headers={"Content-Disposition": "attachment; filename=speech.mp3"}
523
- )
524
-
525
- except HTTPException:
526
- raise
527
- except Exception as e:
528
- logger.error(f"Error in voice synthesis: {e}")
529
- raise HTTPException(status_code=500, detail="Error generating voice output")
530
-
531
- @app.post("/voice/chat")
532
- async def voice_chat_endpoint(
533
- audio: UploadFile = File(...),
534
- language: str = Query("en-US", description="Language code for speech recognition"),
535
- temperature: float = Query(0.7, ge=0.1, le=1.0),
536
- max_new_tokens: int = Query(512, ge=50, le=1024),
537
- current_user: dict = Depends(get_current_user)
538
- ):
539
- logger.info(f"Voice chat initiated by {current_user['email']}")
540
- try:
541
- audio_data = await audio.read()
542
- user_message = recognize_speech(audio_data, language)
543
-
544
- chat_response = agent.chat(
545
- message=user_message,
546
- history=[],
547
- temperature=temperature,
548
- max_new_tokens=max_new_tokens
549
- )
550
-
551
- audio_data = text_to_speech(chat_response, language.split('-')[0])
552
-
553
- return StreamingResponse(
554
- io.BytesIO(audio_data),
555
- media_type="audio/mpeg",
556
- headers={"Content-Disposition": "attachment; filename=response.mp3"}
557
- )
558
-
559
- except HTTPException:
560
- raise
561
- except Exception as e:
562
- logger.error(f"Error in voice chat: {e}")
563
- raise HTTPException(status_code=500, detail="Error processing voice chat")
564
-
565
- @app.post("/analyze-report")
566
- async def analyze_clinical_report(
567
- file: UploadFile = File(...),
568
- patient_id: Optional[str] = Form(None),
569
- temperature: float = Form(0.5),
570
- max_new_tokens: int = Form(1024),
571
- current_user: dict = Depends(get_current_user)
572
- ):
573
- logger.info(f"Report analysis initiated by {current_user['email']}")
574
- try:
575
- content_type = file.content_type
576
- allowed_types = [
577
- 'application/pdf',
578
- 'text/plain',
579
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
580
- ]
581
-
582
- if content_type not in allowed_types:
583
- raise HTTPException(
584
- status_code=400,
585
- detail=f"Unsupported file type: {content_type}. Supported types: PDF, TXT, DOCX"
586
- )
587
-
588
- file_content = await file.read()
589
-
590
- if content_type == 'application/pdf':
591
- text = extract_text_from_pdf(file_content)
592
- elif content_type == 'text/plain':
593
- text = file_content.decode('utf-8')
594
- elif content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
595
- doc = Document(io.BytesIO(file_content))
596
- text = "\n".join([para.text for para in doc.paragraphs])
597
- else:
598
- raise HTTPException(status_code=400, detail="Unsupported file type")
599
-
600
- text = clean_text_response(text)
601
- if len(text.strip()) < 50:
602
- raise HTTPException(
603
- status_code=400,
604
- detail="Extracted text is too short (minimum 50 characters required)"
605
- )
606
-
607
- analysis = await analyze_patient_report(
608
- patient_id=patient_id,
609
- report_content=text,
610
- file_type=content_type,
611
- file_content=file_content
612
- )
613
-
614
- if "_id" in analysis and isinstance(analysis["_id"], ObjectId):
615
- analysis["_id"] = str(analysis["_id"])
616
- if "timestamp" in analysis and isinstance(analysis["timestamp"], datetime):
617
- analysis["timestamp"] = analysis["timestamp"].isoformat()
618
-
619
- return JSONResponse(content=jsonable_encoder({
620
- "status": "success",
621
- "analysis": analysis,
622
- "patient_id": patient_id,
623
- "file_type": content_type,
624
- "file_size": len(file_content)
625
- }))
626
 
627
- except HTTPException:
628
- raise
629
- except Exception as e:
630
- logger.error(f"Error in report analysis: {str(e)}")
631
- raise HTTPException(
632
- status_code=500,
633
- detail=f"Failed to analyze report: {str(e)}"
634
- )
635
 
636
  if __name__ == "__main__":
637
- import uvicorn
638
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import uvicorn
2
+ from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ from config import setup_app
5
+ from endpoints import router
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ # Create the FastAPI app
 
 
 
 
8
  app = FastAPI(title="TxAgent API", version="2.6.0")
9
 
10
+ # Apply CORS middleware
11
  app.add_middleware(
12
  CORSMiddleware,
13
  allow_origins=["*"],
 
16
  allow_headers=["*"]
17
  )
18
 
19
+ # Include the router with endpoints
20
+ app.include_router(router)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Setup the app (e.g., initialize globals, startup event)
23
+ setup_app(app)
 
 
 
 
 
 
24
 
25
  if __name__ == "__main__":
 
26
  uvicorn.run(app, host="0.0.0.0", port=8000)