File size: 6,274 Bytes
f126604 d377221 f126604 5620229 7e095f4 064da3f 7757822 f126604 d377221 7757822 5620229 064da3f f126604 5620229 7e095f4 bdcc052 7e095f4 520104a 7e095f4 520104a 32e4e6a 7757822 5620229 32e4e6a f126604 5620229 f126604 7e095f4 f3089ba 7e095f4 5620229 7e095f4 f3089ba 7e095f4 bdcc052 7e095f4 bdcc052 f3089ba bdcc052 7757822 5620229 7757822 5620229 7757822 5620229 7757822 5620229 7757822 bdcc052 f126604 064da3f ab0089d 064da3f f126604 7757822 bdcc052 f3089ba 5620229 d377221 5620229 7e095f4 bdcc052 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import os
import sys
import json
import re
import logging
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Optional
from datetime import datetime
from pydantic import BaseModel
import markdown
import PyPDF2
# Configure logging: INFO level, with timestamp, logger name and level on each record.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-wide logger used by the startup hook and every endpoint.
logger = logging.getLogger("TxAgentAPI")
# Make the sibling "src" directory importable by putting it first on sys.path.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# TxAgent can only be imported once src_path is on the path; a failed import
# is logged and then re-raised so the module does not load without it.
try:
    from txagent.txagent import TxAgent
except ImportError as import_error:
    logger.error(f"Failed to import TxAgent: {str(import_error)}")
    raise
# FastAPI application object; the metadata below is what the app is created with.
app = FastAPI(
    version="2.0.0",
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
)

# CORS configuration: every origin, method and header is accepted.
# NOTE(review): the FastAPI CORS docs say wildcard origins/methods/headers must
# not be combined with allow_credentials=True. Confirm whether credentials are
# actually needed, or replace "*" with an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request models
class ChatRequest(BaseModel):
    """Request body for the chat endpoint."""

    # Text of the user's message; the only required field.
    message: str
    # Passed to the agent as its ``temperature`` argument.
    temperature: float = 0.7
    # Passed to the agent as its ``max_new_tokens`` argument.
    max_new_tokens: int = 512
    # Prior conversation, passed to the agent unchanged; None when omitted.
    history: Optional[List[Dict]] = None
    # Output format. Options: raw, clean, structured, html
    format: Optional[str] = "clean"
# Response cleaning functions
def clean_text_response(text: str) -> str:
    """Return *text* with emphasis markers removed and whitespace normalised.

    Steps, in order:
      1. Remove every ``**`` and ``__`` marker.
      2. Collapse any run of blank (or whitespace-only) lines into exactly
         one empty line.
      3. Collapse runs of spaces into a single space.
      4. Strip leading and trailing whitespace.

    Args:
        text: The raw text to clean.

    Returns:
        The cleaned text.
    """
    # Markers are stripped first: removing them after whitespace collapsing
    # would leave the gaps they occupied behind (e.g. "a ** b" -> "a  b").
    text = text.replace("**", "").replace("__", "")
    text = re.sub(r'\n\s*\n', '\n\n', text)
    text = re.sub(r'[ ]+', ' ', text)
    return text.strip()
def structure_medical_response(text: str) -> Dict:
    """Split *text* into an overview plus sections under "Type N Diabetes:" headings.

    Everything before the first ``Type <digit> Diabetes:`` heading becomes the
    cleaned overview. Each heading then starts a section whose non-empty,
    stripped lines are stored as a list under the lower-cased heading text
    without the colon (e.g. ``"type 1 diabetes"``). If a heading occurs more
    than once, the last occurrence wins.

    Args:
        text: The response text to structure.

    Returns:
        A dict with the keys ``"overview"`` (str), ``"symptoms"`` (list),
        ``"types"`` (dict of heading -> list of lines) and ``"notes"`` (str).
        ``"symptoms"`` and ``"notes"`` are left at their empty defaults.
    """
    result = {"overview": "", "symptoms": [], "types": {}, "notes": ""}
    heading_pattern = re.compile(r'(Type \d Diabetes:)')

    first_heading = heading_pattern.search(text)
    if first_heading is None:
        # No headings at all: the whole text is the overview. (Slicing with a
        # "not found" index of -1 would silently drop the last character.)
        result["overview"] = clean_text_response(text)
        return result

    overview_end = first_heading.start()
    result["overview"] = clean_text_response(text[:overview_end])

    # Splitting on a capturing group yields ['', heading, body, heading, body, ...];
    # the leading element is empty because the slice starts at a heading.
    parts = heading_pattern.split(text[overview_end:])
    for heading, body in zip(parts[1::2], parts[2::2]):
        type_name = heading.replace(":", "").strip().lower()
        stripped_lines = (line.strip() for line in body.split('\n'))
        result["types"][type_name] = [line for line in stripped_lines if line]
    return result
# Initialize agent at startup
# Module-level agent handle; stays None until startup_event() has fully
# initialised a TxAgent instance.
agent = None


@app.on_event("startup")
async def startup_event():
    """Create the TxAgent instance and load its model when the app starts.

    The global ``agent`` is only assigned after ``init_model()`` has returned,
    so other code never sees a partially initialised agent.

    Raises:
        RuntimeError: If constructing the agent or loading the model fails;
            the original exception is attached as the cause.
    """
    global agent
    try:
        logger.info("Initializing TxAgent...")
        new_agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={},
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100
        )
        new_agent.init_model()
        # Publish only after init_model() succeeded.
        agent = new_agent
        logger.info("TxAgent initialized successfully")
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Failed to initialize agent: %s", e)
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Handle chat conversations with formatting options.

    Sends the message to the agent and returns its reply in the requested
    format. An unknown or missing ``request.format`` falls back to
    ``"clean"``, and the ``"format"`` field of the response names the format
    that was actually applied.

    Args:
        request: The chat request body.

    Returns:
        A JSON response with ``status``, ``format``, ``response``,
        ``timestamp`` and ``available_formats``.

    Raises:
        HTTPException: 503 if the agent has not been initialised; 500 if the
            agent call or the formatting step fails.
    """
    if agent is None:
        raise HTTPException(status_code=503, detail="TxAgent is not initialized")

    # Formatters are callables so that only the requested representation is
    # computed; a failure in an unused formatter can no longer fail the request.
    formatters = {
        "raw": lambda text: text,
        "clean": clean_text_response,
        "structured": structure_medical_response,
        "html": markdown.markdown,
    }
    response_format = request.format if request.format in formatters else "clean"

    try:
        logger.info(f"Chat request received (format: {request.format})")
        raw_response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens
        )
        response_content = formatters[response_format](raw_response)
        return JSONResponse({
            "status": "success",
            "format": response_format,
            "response": response_content,
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(formatters.keys())
        })
    except Exception as e:
        logger.exception("Chat error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Handle file uploads and process with TxAgent.

    Files whose name ends in ``.pdf`` (case-insensitive) are read with PyPDF2
    and the text of all pages is concatenated. Any other file is read fully
    and decoded as UTF-8, ignoring undecodable bytes. The first 10000
    characters are sent to the agent, and the cleaned reply is returned.

    Args:
        file: The uploaded file.

    Returns:
        A JSON response with ``status``, ``format`` (always ``"clean"``),
        ``response``, ``timestamp`` and ``available_formats``.

    Raises:
        HTTPException: 503 if the agent has not been initialised; 500 if
            reading the file or the agent call fails.
    """
    try:
        logger.info(f"File upload received: {file.filename}")
        if agent is None:
            raise HTTPException(status_code=503, detail="TxAgent is not initialized")

        # filename may be missing on the upload; treat that as "not a PDF".
        filename = file.filename or ""
        if filename.lower().endswith('.pdf'):
            pdf_reader = PyPDF2.PdfReader(file.file)
            # extract_text() may yield nothing for a page, hence the `or ""`.
            content = "".join(page.extract_text() or "" for page in pdf_reader.pages)
        else:
            raw_bytes = await file.read()
            content = raw_bytes.decode('utf-8', errors='ignore')

        message = f"Analyze the following medical document content:\n\n{content[:10000]}"
        raw_response = agent.chat(
            message=message,
            history=[],
            temperature=0.7,
            max_new_tokens=512
        )
        # Only the "clean" representation is returned, so only it is computed.
        return JSONResponse({
            "status": "success",
            "format": "clean",
            "response": clean_text_response(raw_response),
            "timestamp": datetime.now().isoformat(),
            "available_formats": ["raw", "clean", "structured", "html"]
        })
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 503 above) from becoming 500s.
        raise
    except Exception as e:
        logger.exception("File upload error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        file.file.close()
@app.get("/status")
async def service_status():
    """Check service status.

    Returns:
        A dict with a fixed ``status`` and ``version``, the loaded agent's
        ``model_name`` (or ``"not loaded"`` while the module-level agent is
        still None), the list of output formats, and the current local
        timestamp in ISO format.
    """
    return {
        "status": "running",
        "version": "2.0.0",
        # Explicit None check: "not loaded" means the agent was never assigned.
        "model": agent.model_name if agent is not None else "not loaded",
        "formats_available": ["raw", "clean", "structured", "html"],
        "timestamp": datetime.now().isoformat()
    }