import json
import logging
import os
import re
import sys
from datetime import datetime
from typing import List, Dict, Optional

import markdown
import PyPDF2
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
|
|
|
|
|
# Configure the root logger at import time: INFO level, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Named logger used by the rest of this module.
logger = logging.getLogger("TxAgentAPI")
|
|
|
|
|
# Put the ``src`` directory that sits beside this file at the front of
# sys.path so that packages inside it can be imported below.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
|
|
|
|
|
# Import the agent class. On failure, log the reason and re-raise so that
# importing this module fails instead of continuing without TxAgent.
try:
    from txagent.txagent import TxAgent
except ImportError as e:
    logger.error(f"Failed to import TxAgent: {str(e)}")
    raise
|
|
|
|
|
# The FastAPI application object; every route below is registered on it.
app = FastAPI(
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
    version="2.0.0",
)
|
|
|
|
|
# Cross-origin policy.
#
# Allowed origins come from the comma-separated TXAGENT_CORS_ORIGINS
# environment variable and default to "*" (any origin), which matches the
# previous hard-coded value.
#
# FastAPI's CORS documentation states that a wildcard origin must not be
# combined with allow_credentials=True, so credentialed cross-origin requests
# are enabled only when explicit origins have been configured.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("TXAGENT_CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body accepted by the ``/chat`` endpoint."""

    # Prompt text forwarded to the agent.
    message: str
    # Passed through as ``temperature`` to the agent's chat call.
    temperature: float = 0.7
    # Passed through as ``max_new_tokens`` to the agent's chat call.
    max_new_tokens: int = 512
    # Earlier conversation entries, if any; their schema is not validated here.
    history: Optional[List[Dict]] = None
    # Which rendering of the answer to return: "raw", "clean", "structured"
    # or "html". Any other value makes the endpoint fall back to "clean".
    format: Optional[str] = "clean"
|
|
|
|
|
def clean_text_response(text: str) -> str:
    """Return *text* with Markdown emphasis markers removed and whitespace tidied.

    Steps, in order:

    1. Remove every ``**`` and ``__`` marker.
    2. Collapse any run of blank (or whitespace-only) lines into exactly one
       blank line.
    3. Collapse runs of spaces into a single space.
    4. Strip leading and trailing whitespace.

    Args:
        text: Raw text, typically a model response.

    Returns:
        The cleaned text; an empty string stays empty.
    """
    # Markers are removed *before* whitespace is normalised so that the gaps
    # they leave behind ("a ** b", or a line holding only "**") are collapsed
    # as well instead of surviving as double spaces / extra blank lines.
    text = text.replace("**", "").replace("__", "")
    text = re.sub(r'\n\s*\n', '\n\n', text)
    text = re.sub(r'[ ]+', ' ', text)
    return text.strip()
|
|
|
def structure_medical_response(text: str) -> Dict:
    """Split a response into an overview plus per-type bullet lists.

    The text before the first ``Type <digit> Diabetes:`` heading becomes the
    overview (passed through ``clean_text_response``). Each heading then
    starts a section whose non-empty, stripped lines are stored under the
    lower-cased heading without its colon, e.g. ``"type 1 diabetes"``.
    If the same heading occurs twice, the later section replaces the earlier.

    Args:
        text: Raw response text.

    Returns:
        A dict with the keys ``overview`` (str), ``symptoms`` (list),
        ``types`` (dict of heading -> list of lines) and ``notes`` (str).
        ``symptoms`` and ``notes`` are always present but are not populated
        by this function.
    """
    result = {"overview": "", "symptoms": [], "types": {}, "notes": ""}

    first_heading = re.search(r'Type \d Diabetes:', text)
    if first_heading is None:
        # No typed sections at all: the whole text is the overview.
        # (Slicing with str.find()'s -1 would silently drop the last
        # character and then parse that character as a section.)
        result["overview"] = clean_text_response(text)
        return result

    overview_end = first_heading.start()
    result["overview"] = clean_text_response(text[:overview_end])

    # Because the pattern has a capturing group, re.split returns
    # [prefix, heading, body, heading, body, ...]; the prefix is empty here
    # since the slice starts exactly at the first heading.
    parts = re.split(r'(Type \d Diabetes:)', text[overview_end:])
    for heading, body in zip(parts[1::2], parts[2::2]):
        type_key = heading.replace(":", "").strip().lower()
        result["types"][type_key] = [
            line.strip() for line in body.split('\n') if line.strip()
        ]
    return result
|
|
|
|
|
# Shared TxAgent instance; remains None until the startup hook assigns it.
agent: Optional["TxAgent"] = None
|
|
|
# NOTE(review): recent FastAPI versions recommend lifespan handlers over
# on_event; the decorator is kept so the registered hook is unchanged.
@app.on_event("startup")
async def startup_event():
    """Create the TxAgent instance and load its model at application start-up.

    The agent is built with RAG disabled and an empty tool-file mapping, then
    ``init_model()`` is called. The module-level ``agent`` is assigned only
    after both steps succeed, so a failed initialisation never leaves a
    half-initialised agent visible to the request handlers.

    Raises:
        RuntimeError: If construction or model initialisation fails. The
            original exception is chained as ``__cause__``.
    """
    global agent
    try:
        logger.info("Initializing TxAgent...")
        new_agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={},
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,
        )
        new_agent.init_model()
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Failed to initialize agent: {str(e)}")
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e

    agent = new_agent
    logger.info("TxAgent initialized successfully")
|
|
|
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Run one chat turn through the agent and return it in the chosen format.

    Args:
        request: Message, generation settings, optional history and the
            desired output format ("raw", "clean", "structured" or "html").

    Returns:
        A JSON response with ``status``, the ``format`` actually returned,
        the formatted ``response``, a ``timestamp`` and the list of
        ``available_formats``. An unknown or missing format falls back to
        "clean", and ``format`` reports "clean" in that case.

    Raises:
        HTTPException: 503 if the agent has not been initialised; 500 if the
            agent call or the formatting fails.
    """
    if agent is None:
        # Report "service unavailable" instead of a 500 caused by calling
        # .chat on None.
        raise HTTPException(status_code=503, detail="TxAgent is not initialized")

    # One formatter per output format; only the selected one is executed.
    formatters = {
        "raw": lambda text: text,
        "clean": clean_text_response,
        "structured": structure_medical_response,
        "html": markdown.markdown,
    }

    try:
        logger.info(f"Chat request received (format: {request.format})")
        # NOTE(review): this call is not awaited; if it blocks, the event
        # loop is stalled for its duration. Moving it to a thread pool needs
        # confirmation that the agent is safe to call concurrently.
        raw_response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
        )

        selected_format = request.format if request.format in formatters else "clean"
        response_content = formatters[selected_format](raw_response)

        return JSONResponse({
            "status": "success",
            "format": selected_format,
            "response": response_content,
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(formatters),
        })
    except HTTPException:
        # Deliberate HTTP errors must not be rewrapped as a generic 500.
        raise
    except Exception as e:
        logger.exception(f"Chat error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Extract text from an uploaded document and have the agent analyse it.

    Files whose name ends in ``.pdf`` (case-insensitive) are read page by
    page with PyPDF2; everything else is decoded as UTF-8, ignoring
    undecodable bytes. Only the first 10000 characters of the extracted text
    are sent to the agent.

    Args:
        file: The uploaded file.

    Returns:
        A JSON response with ``status``, ``format`` (always "clean"), the
        cleaned ``response``, a ``timestamp`` and ``available_formats``.

    Raises:
        HTTPException: 503 if the agent has not been initialised; 400 if no
            text could be extracted from the upload; 500 for any other
            failure while reading the file or calling the agent.
    """
    try:
        if agent is None:
            raise HTTPException(status_code=503, detail="TxAgent is not initialized")

        logger.info(f"File upload received: {file.filename}")

        # The filename may be missing, and extensions such as ".PDF" must
        # still be routed to the PDF reader rather than decoded as text.
        filename = file.filename or ""
        if filename.lower().endswith('.pdf'):
            pdf_reader = PyPDF2.PdfReader(file.file)
            # extract_text() may yield nothing for a page; treat that as "".
            content = "".join(
                page.extract_text() or "" for page in pdf_reader.pages
            )
        else:
            raw_bytes = await file.read()
            content = raw_bytes.decode('utf-8', errors='ignore')

        if not content.strip():
            # Without any text the agent would be asked to analyse nothing.
            raise HTTPException(
                status_code=400,
                detail="No readable text found in uploaded file",
            )

        message = f"Analyze the following medical document content:\n\n{content[:10000]}"
        raw_response = agent.chat(
            message=message,
            history=[],
            temperature=0.7,
            max_new_tokens=512,
        )

        # Only the "clean" rendering is returned, so only that one is built;
        # the advertised format names are unchanged.
        return JSONResponse({
            "status": "success",
            "format": "clean",
            "response": clean_text_response(raw_response),
            "timestamp": datetime.now().isoformat(),
            "available_formats": ["raw", "clean", "structured", "html"],
        })
    except HTTPException:
        # Deliberate HTTP errors must not be rewrapped as a generic 500.
        raise
    except Exception as e:
        logger.exception(f"File upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        file.file.close()
|
|
|
@app.get("/status")
async def service_status():
    """Report that the service is up, plus version, model and format info.

    Returns:
        A dict with ``status``, ``version``, ``model`` (the agent's
        ``model_name``, or "not loaded" when no agent is set),
        ``formats_available`` and a ``timestamp``.
    """
    if agent:
        loaded_model = agent.model_name
    else:
        loaded_model = "not loaded"

    return {
        "status": "running",
        "version": "2.0.0",
        "model": loaded_model,
        "formats_available": ["raw", "clean", "structured", "html"],
        "timestamp": datetime.now().isoformat(),
    }