"""TxAgent HTTP API.

FastAPI service wrapping a single process-wide ``TxAgent`` instance:

* ``POST /chat``   -- chat with selectable output formatting.
* ``POST /upload`` -- analyse an uploaded PDF or plain-text document.
* ``GET  /status`` -- service and model information.
"""

import json
import logging
import os
import re
import sys
import threading
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

import markdown
import PyPDF2
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("TxAgentAPI")

# Put the bundled ``src`` directory first on sys.path so ``txagent`` resolves
# from this checkout.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# Import TxAgent only after the path above is in place.
try:
    from txagent.txagent import TxAgent
except ImportError as e:
    logger.error("Failed to import TxAgent: %s", e)
    raise

API_VERSION = "2.0.0"

# Settings for /upload analysis (previously inline literals).
MAX_DOCUMENT_CHARS = 10000
UPLOAD_TEMPERATURE = 0.7
UPLOAD_MAX_NEW_TOKENS = 512

# Initialize FastAPI app
app = FastAPI(
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
    version=API_VERSION,
)

# NOTE(review): wildcard origins combined with allow_credentials=True
# effectively lets any website make credentialed cross-origin requests;
# restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ChatRequest(BaseModel):
    """Body of ``POST /chat``."""

    message: str
    temperature: float = 0.7
    max_new_tokens: int = 512
    history: Optional[List[Dict]] = None
    format: Optional[str] = "clean"  # Options: raw, clean, structured, html


# ---------------------------------------------------------------------------
# Response formatting
# ---------------------------------------------------------------------------

_BLANK_LINE_RUN = re.compile(r"\n\s*\n")
_SPACE_RUN = re.compile(r"[ ]+")
# Capturing group so ``split`` keeps the headers as separate list items.
_DIABETES_TYPE_HEADER = re.compile(r"(Type \d+ Diabetes:)")


def clean_text_response(text: str) -> str:
    """Return *text* with light whitespace and emphasis cleanup.

    Collapses any whitespace-only gap between two newlines into a single
    blank line, squeezes runs of spaces to one space, removes ``**`` and
    ``__`` emphasis markers, and strips leading/trailing whitespace.
    """
    text = _BLANK_LINE_RUN.sub("\n\n", text)
    text = _SPACE_RUN.sub(" ", text)
    text = text.replace("**", "").replace("__", "")
    return text.strip()


def structure_medical_response(text: str) -> Dict:
    """Split a diabetes-oriented answer into an overview and per-type points.

    Everything before the first ``"Type <N> Diabetes:"`` header becomes the
    cleaned ``overview``. Each header starts a ``types`` entry keyed by the
    lower-cased header without its colon (e.g. ``"type 1 diabetes"``); its
    value is the list of non-empty, stripped lines that follow the header.
    If a header repeats, the later section wins. ``symptoms`` and ``notes``
    are never populated; they are kept so the response shape stays stable.

    When no header is present, the whole text is the overview.
    """
    result = {"overview": "", "symptoms": [], "types": {}, "notes": ""}

    first_header = _DIABETES_TYPE_HEADER.search(text)
    if first_header is None:
        # The previous ``str.find`` version received -1 here and sliced the
        # last character off the overview.
        result["overview"] = clean_text_response(text)
        return result

    start = first_header.start()
    result["overview"] = clean_text_response(text[:start])

    current_type = None
    for section in _DIABETES_TYPE_HEADER.split(text[start:]):
        if _DIABETES_TYPE_HEADER.fullmatch(section):
            current_type = section.replace(":", "").strip().lower()
            result["types"][current_type] = []
        elif current_type:
            result["types"][current_type] = [
                line.strip() for line in section.split("\n") if line.strip()
            ]
    return result


def _as_raw(text: str) -> str:
    """Return *text* unchanged (the ``raw`` format)."""
    return text


# One formatter per supported output format. Insertion order defines the
# order of ``available_formats`` / ``formats_available`` in responses.
_FORMATTERS: Dict[str, Callable[[str], Any]] = {
    "raw": _as_raw,
    "clean": clean_text_response,
    "structured": structure_medical_response,
    "html": markdown.markdown,
}
DEFAULT_FORMAT = "clean"


def _resolve_format(requested: Optional[str]) -> str:
    """Return *requested* if it is a known format, else ``DEFAULT_FORMAT``."""
    return requested if requested in _FORMATTERS else DEFAULT_FORMAT


# ---------------------------------------------------------------------------
# Agent lifecycle
# ---------------------------------------------------------------------------

# Process-wide agent; assigned by ``startup_event`` only once fully initialised.
agent: Optional[TxAgent] = None

# NOTE(review): nothing here shows that TxAgent.chat is thread-safe. Inference
# now runs in worker threads (so it no longer blocks the event loop); this lock
# keeps calls one-at-a-time, matching the effective behaviour of the previous
# event-loop-blocking implementation.
_agent_lock = threading.Lock()


def _require_agent() -> TxAgent:
    """Return the initialised agent, or raise HTTP 503 if there is none."""
    if agent is None:
        raise HTTPException(status_code=503, detail="TxAgent is not initialized")
    return agent


def _chat_serialized(tx_agent: TxAgent, **chat_kwargs: Any) -> str:
    """Call ``tx_agent.chat(**chat_kwargs)`` while holding ``_agent_lock``.

    Blocking; intended to be run through ``run_in_threadpool``.
    """
    with _agent_lock:
        return tx_agent.chat(**chat_kwargs)


@app.on_event("startup")
async def startup_event():
    """Construct and initialise the global TxAgent.

    The global ``agent`` is assigned only after ``init_model()`` succeeds,
    so a failed start never leaves a half-initialised agent behind.

    Raises:
        RuntimeError: if construction or model initialisation fails.
    """
    global agent
    try:
        logger.info("Initializing TxAgent...")
        tx_agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={},
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,
        )
        tx_agent.init_model()
    except Exception as e:
        logger.exception("Failed to initialize agent: %s", e)
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e
    agent = tx_agent
    logger.info("TxAgent initialized successfully")


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Chat with the agent and return the reply in the requested format.

    Unknown or missing ``format`` values fall back to ``"clean"``; the
    ``format`` field of the response reports the format actually used.

    Raises:
        HTTPException: 503 if the agent is not initialised; 500 on any
            failure during generation or formatting.
    """
    logger.info("Chat request received (format: %s)", request.format)
    tx_agent = _require_agent()
    response_format = _resolve_format(request.format)
    try:
        raw_response = await run_in_threadpool(
            _chat_serialized,
            tx_agent,
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
        )
        # Build only the requested representation, not all of them.
        response_content = _FORMATTERS[response_format](raw_response)
    except Exception as e:
        logger.exception("Chat error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse({
        "status": "success",
        "format": response_format,
        "response": response_content,
        "timestamp": datetime.now().isoformat(),
        "available_formats": list(_FORMATTERS),
    })


def _extract_upload_text(upload: UploadFile) -> str:
    """Return the text content of *upload* (blocking; run it in a thread).

    Files whose name ends in ``.pdf`` (case-insensitive) are parsed with
    PyPDF2; pages are newline-separated and pages with no extractable text
    contribute an empty string. Any other file is decoded as UTF-8, with
    undecodable bytes dropped. A missing filename is treated as non-PDF.
    """
    filename = (upload.filename or "").lower()
    if filename.endswith(".pdf"):
        reader = PyPDF2.PdfReader(upload.file)
        # Newline-join so the last word of one page is not glued to the first
        # word of the next (pages used to be concatenated directly).
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    return upload.file.read().decode("utf-8", errors="ignore")


@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Analyse an uploaded document with TxAgent and return a cleaned reply.

    Only the first ``MAX_DOCUMENT_CHARS`` characters of the extracted text
    are sent to the model. The upload is always closed before returning.

    Raises:
        HTTPException: 503 if the agent is not initialised; 500 on any
            failure while reading the file or generating the analysis.
    """
    try:
        logger.info("File upload received: %s", file.filename)
        tx_agent = _require_agent()
        content = await run_in_threadpool(_extract_upload_text, file)
        message = (
            "Analyze the following medical document content:\n\n"
            f"{content[:MAX_DOCUMENT_CHARS]}"
        )
        raw_response = await run_in_threadpool(
            _chat_serialized,
            tx_agent,
            message=message,
            history=[],
            temperature=UPLOAD_TEMPERATURE,
            max_new_tokens=UPLOAD_MAX_NEW_TOKENS,
        )
        response_content = clean_text_response(raw_response)
    except HTTPException:
        # Propagate the 503 from _require_agent instead of re-wrapping as 500.
        raise
    except Exception as e:
        logger.exception("File upload error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        await file.close()

    return JSONResponse({
        "status": "success",
        "format": "clean",
        "response": response_content,
        "timestamp": datetime.now().isoformat(),
        "available_formats": list(_FORMATTERS),
    })


@app.get("/status")
async def service_status():
    """Report liveness, API version, loaded model name and output formats."""
    return {
        "status": "running",
        "version": API_VERSION,
        "model": agent.model_name if agent is not None else "not loaded",
        "formats_available": list(_FORMATTERS),
        "timestamp": datetime.now().isoformat(),
    }