import os
import sys
import json
import shutil
import logging
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Optional
import torch
from datetime import datetime
from pydantic import BaseModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('txagent_api.log')
    ]
)
logger = logging.getLogger("TxAgentAPI")

# Add src directory to Python path
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# Import TxAgent after setting up path
try:
    from txagent.txagent import TxAgent
except ImportError as e:
    logger.error(f"Failed to import TxAgent: {str(e)}")
    raise

# Configuration
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

# Create directories if they don't exist
for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(directory, exist_ok=True)
    logger.info(f"Created directory: {directory}")

# Set environment variables
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir

# Request models
class ChatRequest(BaseModel):
    message: str
    temperature: float = 0.7
    max_new_tokens: int = 512
    history: Optional[List[Dict]] = None

class MultistepRequest(BaseModel):
    message: str
    temperature: float = 0.7
    max_new_tokens: int = 512
    max_round: int = 5
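
# Example request bodies (illustrative sketch only; field names and defaults
# come from the Pydantic models above):
#
#   POST /chat       {"message": "What are common side effects of metformin?",
#                     "temperature": 0.7, "max_new_tokens": 512}
#   POST /multistep  {"message": "Outline a stepwise workup for statin intolerance.",
#                     "max_round": 5}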

# Initialize FastAPI app
app = FastAPI(
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
    version="1.0.0"
)

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize agent at startup
agent = None

@app.on_event("startup")
async def startup_event():
    global agent
    try:
        logger.info("Initializing TxAgent...")
        agent = init_agent()
        logger.info("TxAgent initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize agent: {str(e)}", exc_info=True)
        raise RuntimeError(f"Failed to initialize agent: {str(e)}")

def init_agent():
    """Initialize and return the TxAgent instance"""
    try:
        tool_path = os.path.join(tool_cache_dir, "new_tool.json")
        if not os.path.exists(tool_path):
            logger.info(f"Copying tool file to {tool_path}")
            # Resolve the bundled default tool file relative to this script rather than the CWD
            default_tool = os.path.join(current_dir, "data", "new_tool.json")
            if os.path.exists(default_tool):
                shutil.copy(default_tool, tool_path)
            else:
                raise FileNotFoundError(f"Default tool file not found at {default_tool}")
        
        logger.info("Creating TxAgent instance")
        agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={"new_tool": tool_path},
            enable_finish=True,
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100
        )
        agent.init_model()
        return agent
    except Exception as e:
        logger.error(f"Error in init_agent: {str(e)}", exc_info=True)
        raise

@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Handle chat conversations"""
    try:
        logger.info(f"Chat request received: {request.message[:50]}...")
        response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens
        )
        logger.info("Chat response generated successfully")
        return JSONResponse({
            "status": "success",
            "response": response,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        logger.error(f"Chat error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/multistep")
async def multistep_endpoint(request: MultistepRequest):
    """Run multi-step reasoning"""
    try:
        logger.info(f"Multistep request received: {request.message[:50]}...")
        response = agent.run_multistep_agent(
            message=request.message,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
            max_round=request.max_round
        )
        logger.info("Multistep reasoning completed successfully")
        return JSONResponse({
            "status": "success",
            "response": response,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        logger.error(f"Multistep error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/analyze")
async def analyze_document(file: UploadFile = File(...)):
    """Analyze a medical document"""
    try:
        logger.info(f"Document analysis request received for: {file.filename}")
        
        # Save the uploaded file temporarily, keeping only the basename to avoid path traversal
        safe_filename = os.path.basename(file.filename)
        temp_path = os.path.join(file_cache_dir, safe_filename)
        with open(temp_path, "wb") as f:
            f.write(await file.read())
        logger.info(f"File saved temporarily at {temp_path}")
        
        # Process the document
        text = agent.extract_text_from_file(temp_path)
        analysis = agent.analyze_text(text)
        logger.info("Document analysis completed successfully")
        
        # Generate report
        report_filename = f"{os.path.splitext(file.filename)[0]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        report_path = os.path.join(report_dir, report_filename)
        with open(report_path, "w") as f:
            json.dump({
                "filename": file.filename,
                "analysis": analysis,
                "timestamp": datetime.now().isoformat()
            }, f, indent=2)
        logger.info(f"Report generated at {report_path}")
        
        # Clean up
        os.remove(temp_path)
        logger.info(f"Temporary file {temp_path} removed")
        
        return JSONResponse({
            "status": "success",
            "analysis": analysis,
            "report_path": report_path,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        logger.error(f"Document analysis error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/status")
async def service_status():
    """Check service status"""
    status = {
        "status": "running",
        "version": "1.0.0",
        "model": agent.model_name if agent else "not loaded",
        "device": str(agent.device) if agent else "unknown",
        "timestamp": datetime.now().isoformat()
    }
    logger.info(f"Status check: {status}")
    return status

if __name__ == "__main__":
    try:
        logger.info("Starting TxAgent API server")
        import uvicorn
        uvicorn.run(app, host="0.0.0.0", port=8000, log_config=None)
    except Exception as e:
        logger.error(f"Failed to start server: {str(e)}", exc_info=True)
        raise
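
# Example client usage (a minimal sketch only; assumes the server is reachable at
# http://localhost:8000, that the `requests` package is installed, and that
# "labs.pdf" is a hypothetical local file):
#
#   import requests
#
#   # Chat
#   r = requests.post("http://localhost:8000/chat",
#                     json={"message": "What does lisinopril treat?"})
#   print(r.json()["response"])
#
#   # Document analysis
#   with open("labs.pdf", "rb") as fh:
#       r = requests.post("http://localhost:8000/analyze",
#                         files={"file": ("labs.pdf", fh, "application/pdf")})
#   print(r.json()["report_path"])
#
#   # Service status
#   print(requests.get("http://localhost:8000/status").json())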