File size: 6,274 Bytes
f126604
d377221
f126604
5620229
7e095f4
064da3f
7757822
f126604
d377221
 
7757822
5620229
064da3f
f126604
5620229
7e095f4
 
bdcc052
7e095f4
 
 
520104a
 
 
 
 
7e095f4
 
 
 
 
 
520104a
32e4e6a
 
7757822
 
5620229
32e4e6a
f126604
 
 
 
 
 
 
 
 
 
5620229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f126604
 
 
 
 
 
 
7e095f4
 
f3089ba
7e095f4
5620229
7e095f4
 
 
 
 
f3089ba
7e095f4
 
bdcc052
7e095f4
bdcc052
f3089ba
bdcc052
7757822
 
5620229
7757822
5620229
 
7757822
 
 
 
 
5620229
 
 
 
 
 
 
7757822
 
5620229
 
 
 
7757822
 
bdcc052
f126604
 
064da3f
 
 
 
 
 
 
 
 
 
 
 
 
 
ab0089d
064da3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f126604
 
7757822
bdcc052
f3089ba
5620229
d377221
5620229
7e095f4
bdcc052
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import sys
import json
import re
import logging
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Optional
from datetime import datetime
from pydantic import BaseModel
import markdown
import PyPDF2

# Configure logging: INFO and above, every record prefixed with a timestamp,
# the logger name and the level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("TxAgentAPI")

# Put the ``src`` directory that sits next to this file at the front of
# sys.path so that the ``txagent`` package can be imported from it.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# TxAgent is importable only once src_path is on sys.path. An import failure
# is logged and then re-raised, so the module never loads without TxAgent.
try:
    from txagent.txagent import TxAgent
except ImportError as import_error:
    logger.error(f"Failed to import TxAgent: {str(import_error)}")
    raise

# FastAPI application. The title, description and version feed the
# generated OpenAPI documentation.
app = FastAPI(
    version="2.0.0",
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
)

# CORS: accept cross-origin requests from any origin, with any method and
# any header, and allow credentials.
# NOTE(review): FastAPI's CORS documentation says wildcard origins, methods
# and headers are not meant to be combined with allow_credentials=True.
# Confirm whether credentialed cross-origin requests are actually needed
# before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Request models
class ChatRequest(BaseModel):
    """JSON body accepted by the ``/chat`` endpoint."""

    # Prompt text, passed to agent.chat as ``message``.
    message: str
    # Sampling temperature, passed to agent.chat.
    temperature: float = 0.7
    # Generation length limit, passed to agent.chat as ``max_new_tokens``.
    max_new_tokens: int = 512
    # Prior conversation, passed to agent.chat unchanged. The shape of each
    # dict is defined by TxAgent.chat (presumably role/content turns - TODO confirm).
    history: Optional[List[Dict]] = None
    # Output rendering: "raw", "clean", "structured" or "html". The /chat
    # handler returns the "clean" rendering for any other value.
    format: Optional[str] = "clean"  # Options: raw, clean, structured, html

# Response cleaning functions
def clean_text_response(text: str) -> str:
    """Normalise model output into tidy plain text.

    Steps, applied in order:
      1. Runs of blank or whitespace-only lines collapse to one blank line.
      2. Runs of space characters collapse to a single space (tabs are kept).
      3. Markdown emphasis markers ``**`` and ``__`` are removed. This also
         strips the underscores of dunder names such as ``__init__``.
      4. Leading and trailing whitespace is stripped.
    """
    tidied = re.sub(r'\n\s*\n', '\n\n', text)
    tidied = re.sub(r'[ ]+', ' ', tidied)
    for marker in ("**", "__"):
        tidied = tidied.replace(marker, "")
    return tidied.strip()

def structure_medical_response(text: str) -> Dict:
    """Split a diabetes-oriented answer into an overview and per-type sections.

    The text is split on headers of the form ``"Type <digit> Diabetes:"``.

    - ``overview``: everything before the first header, passed through
      ``clean_text_response``. If no header is present, this is the whole text.
    - ``types``: for each header, the non-empty, stripped lines that follow it.
      The key is the lower-cased header without the colon, e.g.
      ``"type 1 diabetes"``. If a header repeats, only its last section is kept.
    - ``symptoms`` and ``notes`` are part of the returned shape but are not
      populated by this parser.
    """
    result = {"overview": "", "symptoms": [], "types": {}, "notes": ""}
    # With one capturing group, re.split returns
    # [before_first_header, header1, body1, header2, body2, ...].
    # Headers therefore sit at odd indices, each followed by its body. This
    # avoids str.find() returning -1 when there is no "Type 1" header, and
    # avoids mistaking body text that merely starts with "Type " for a header.
    parts = re.split(r'(Type \d Diabetes:)', text)
    result["overview"] = clean_text_response(parts[0])
    for header, body in zip(parts[1::2], parts[2::2]):
        type_name = header.replace(":", "").strip().lower()
        result["types"][type_name] = [
            line.strip() for line in body.split('\n') if line.strip()
        ]
    return result

# Initialize agent at startup
# Module-level handle to the loaded TxAgent. It stays None until
# startup_event finishes successfully; /status reports "not loaded" while
# it is falsy.
agent = None

@app.on_event("startup")
async def startup_event():
    """Load the TxAgent model once, when the application starts.

    The agent is built and ``init_model()`` is run on a local variable. Only
    after both succeed is the agent published to the module-level ``agent``,
    so a failed load never leaves a half-initialised agent visible to the
    request handlers.

    Raises:
        RuntimeError: constructing the agent or loading its model failed.
            The original exception is chained as ``__cause__``.
    """
    global agent
    logger.info("Initializing TxAgent...")
    try:
        new_agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={},
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100
        )
        new_agent.init_model()
    except Exception as e:
        # logger.exception records the traceback, which str(e) alone loses.
        logger.exception(f"Failed to initialize agent: {str(e)}")
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e
    agent = new_agent
    logger.info("TxAgent initialized successfully")

@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Run one chat turn through TxAgent and return it in the requested format.

    ``request.format`` selects how the model output is rendered:

    - ``raw``: unchanged.
    - ``clean``: via ``clean_text_response``.
    - ``structured``: dict from ``structure_medical_response``.
    - ``html``: Markdown rendered with ``markdown.markdown``.

    A missing or unknown value falls back to ``clean``. The response's
    ``format`` field reports the rendering that was actually used.

    Raises:
        HTTPException: 503 if the model has not been loaded; 500 if the
            agent call or the formatting fails.
    """
    logger.info(f"Chat request received (format: {request.format})")
    if agent is None:
        raise HTTPException(status_code=503, detail="TxAgent model is not loaded")

    # Only the requested rendering is computed. Building every rendering up
    # front wasted work, and let a failure in an unused formatter fail the
    # whole request.
    formatters = {
        "raw": lambda text: text,
        "clean": clean_text_response,
        "structured": structure_medical_response,
        "html": markdown.markdown,
    }
    output_format = request.format if request.format in formatters else "clean"
    try:
        # NOTE(review): agent.chat is synchronous and runs on the event-loop
        # thread, so it blocks every other request while generating. Moving it
        # to a worker thread (fastapi.concurrency.run_in_threadpool) first
        # needs confirmation that TxAgent is safe to call concurrently.
        raw_response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens
        )
        response_content = formatters[output_format](raw_response)
        return JSONResponse({
            "status": "success",
            "format": output_format,
            "response": response_content,
            "timestamp": datetime.now().isoformat(),
            "available_formats": list(formatters),
        })
    except Exception as e:
        logger.exception(f"Chat error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Extract text from an uploaded document and have TxAgent analyse it.

    PDFs are read page by page with PyPDF2, with a newline between pages. A
    file is treated as a PDF if its name ends in ``.pdf`` (any letter case)
    or its content type is ``application/pdf``. Any other upload is decoded
    as UTF-8, dropping undecodable bytes.

    Only the first 10,000 characters of extracted text are sent to the model.
    The reply is always returned in the ``clean`` rendering.

    Raises:
        HTTPException: 503 if the model has not been loaded; 500 if text
            extraction or the agent call fails.
    """
    max_chars = 10000  # cap on document text forwarded to the model
    try:
        logger.info(f"File upload received: {file.filename}")
        if agent is None:
            raise HTTPException(status_code=503, detail="TxAgent model is not loaded")

        # UploadFile.filename is optional, so guard before calling str methods.
        filename = file.filename or ""
        is_pdf = (
            filename.lower().endswith(".pdf")
            or file.content_type == "application/pdf"
        )
        if is_pdf:
            pdf_reader = PyPDF2.PdfReader(file.file)
            # extract_text() may yield None/empty for pages without a text
            # layer. Joining with "\n" keeps the last word of one page from
            # running into the first word of the next.
            content = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
        else:
            raw_bytes = await file.read()
            content = raw_bytes.decode('utf-8', errors='ignore')

        message = f"Analyze the following medical document content:\n\n{content[:max_chars]}"
        # NOTE(review): agent.chat is synchronous and blocks the event loop
        # while generating; see the concurrency caveat on TxAgent before
        # offloading it to a thread.
        raw_response = agent.chat(
            message=message,
            history=[],
            temperature=0.7,
            max_new_tokens=512
        )
        return JSONResponse({
            "status": "success",
            "format": "clean",
            "response": clean_text_response(raw_response),
            "timestamp": datetime.now().isoformat(),
            "available_formats": ["raw", "clean", "structured", "html"],
        })
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 503 above) pass through unchanged.
        raise
    except Exception as e:
        logger.exception(f"File upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        file.file.close()

@app.get("/status")
async def service_status():
    """Report liveness, API version, loaded model and supported formats.

    Returns a plain dict that FastAPI serialises to JSON. ``model`` is
    ``"not loaded"`` while the module-level ``agent`` is falsy; otherwise it
    is ``agent.model_name``. That attribute presumably holds the name passed
    to TxAgent at startup - TODO confirm.
    """
    if agent:
        loaded_model = agent.model_name
    else:
        loaded_model = "not loaded"
    status_report = {
        "status": "running",
        "version": "2.0.0",
        "model": loaded_model,
        "formats_available": ["raw", "clean", "structured", "html"],
        "timestamp": datetime.now().isoformat(),
    }
    return status_report