File size: 3,096 Bytes
f126604
d377221
f126604
7e095f4
7757822
 
f126604
d377221
 
7757822
f126604
bdcc052
7e095f4
 
bdcc052
7e095f4
 
 
520104a
 
 
 
 
7e095f4
 
 
 
 
 
520104a
32e4e6a
 
7757822
 
32e4e6a
 
f126604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e095f4
 
 
 
bdcc052
7e095f4
 
 
 
 
 
 
 
bdcc052
7e095f4
bdcc052
 
 
 
 
 
 
 
f126604
7757822
 
 
 
7e095f4
7757822
 
 
 
 
 
 
 
 
 
 
 
bdcc052
f126604
 
 
 
7757822
bdcc052
f126604
 
d377221
7e095f4
bdcc052
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import json
import logging
import os
import sys
import threading
from datetime import datetime
from typing import Dict, List, Optional

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Logging setup (intended for Hugging Face Spaces): INFO-level records with a
# timestamp, logger name and level prefix, configured on the root logger.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Named logger used throughout this module.
logger = logging.getLogger("TxAgentAPI")

# Locate the "src" directory that sits next to this file and put it at the very
# front of the module search path, so the txagent import below resolves from it
# before any other location.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path[:0] = [src_path]

# TxAgent is imported only after the sys.path adjustment above, since it is
# expected to be found under the local "src" directory.
try:
    from txagent.txagent import TxAgent
except ImportError as import_err:
    # Record why the import failed, then let the error propagate unchanged.
    logger.error(f"Failed to import TxAgent: {str(import_err)}")
    raise

# The FastAPI application object. Title, description and version are the
# service metadata handed to FastAPI.
app = FastAPI(
    version="1.0.0",
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
)

# CORS configuration: accept cross-origin requests from any origin, with any
# method and any header, and allow credentials.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any website make credentialed cross-origin calls to this API; confirm this
# is intended before adding cookie- or auth-based features.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Shared TxAgent instance. It stays None until the startup hook has built and
# initialized it; endpoints read it from module scope.
agent: Optional["TxAgent"] = None

@app.on_event("startup")
async def startup_event():
    """Construct the shared TxAgent and initialize its model at application startup.

    The agent is built and initialized in a local variable. The module-level
    ``agent`` is assigned only after ``init_model()`` has returned, so a
    failure can never leave a half-initialized agent visible to the endpoints.

    Raises:
        RuntimeError: if constructing the agent or initializing its model
            fails; the original exception is chained as ``__cause__``.
    """
    # NOTE(review): recent FastAPI releases deprecate on_event in favor of a
    # lifespan handler; migrating requires changing the FastAPI(...) call too.
    global agent
    try:
        logger.info("Initializing TxAgent...")
        new_agent = TxAgent(
            model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
            rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
            tool_files_dict={},  # No tool files in this example
            enable_finish=True,
            # NOTE(review): RAG is disabled, so rag_model_name and step_rag_num
            # are presumably unused here; confirm against TxAgent.
            enable_rag=False,
            force_finish=True,
            enable_checker=True,
            step_rag_num=4,
            seed=100,
        )
        # Model initialization is a separate step from construction.
        new_agent.init_model()
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Failed to initialize agent: {e}")
        raise RuntimeError(f"Failed to initialize agent: {e}") from e
    # Publish only a fully initialized agent.
    agent = new_agent
    logger.info("TxAgent initialized successfully")

class ChatRequest(BaseModel):
    """JSON request body accepted by the ``/chat`` endpoint.

    Each field is forwarded unchanged as the same-named keyword argument of
    ``agent.chat`` in ``chat_endpoint``.
    """

    # The user's message; required (no default).
    message: str
    # Presumably the sampling temperature. NOTE(review): no range validation.
    temperature: float = 0.7
    # Presumably a cap on generated tokens. NOTE(review): not checked to be positive.
    max_new_tokens: int = 512
    # Prior conversation turns. The dict schema is not shown here; confirm
    # against TxAgent.chat.
    history: Optional[List[Dict]] = None

# Serializes calls into the shared agent. While agent.chat ran directly on the
# event loop, only one chat could run at a time. The lock keeps that guarantee
# now that the call runs in a worker thread.
# NOTE(review): drop the lock only once TxAgent.chat is confirmed thread-safe.
_chat_lock = threading.Lock()


def _chat_serialized(**chat_kwargs):
    """Call ``agent.chat`` with *chat_kwargs* while holding ``_chat_lock``."""
    with _chat_lock:
        return agent.chat(**chat_kwargs)


@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Handle chat conversations.

    Forwards the message, history and generation settings to the shared
    TxAgent and wraps its reply in a JSON envelope with a local timestamp.

    Raises:
        HTTPException: 503 if the agent has not been initialized; 500 carrying
            the error text if the chat call or response construction fails.
    """
    if agent is None:
        # Checked outside the try so this 503 is not rewrapped as a 500 below.
        raise HTTPException(status_code=503, detail="Agent not initialized")
    try:
        # NOTE(review): this logs the first 50 characters of user input, which
        # may contain sensitive medical details.
        logger.info(f"Chat request received: {request.message[:50]}...")
        # agent.chat is synchronous (not awaited). Running it in FastAPI's
        # threadpool keeps the event loop free to serve other requests
        # (e.g. /status) while a reply is being generated.
        response = await run_in_threadpool(
            _chat_serialized,
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
        )
        return JSONResponse({
            "status": "success",
            "response": response,
            "timestamp": datetime.now().isoformat(),
        })
    except Exception as e:
        logger.exception(f"Chat error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/status")
async def service_status():
    """Check service status.

    Returns:
        A dict containing:
        - ``status``: always "running" here.
        - ``version``: the FastAPI app's version.
        - ``model``: the agent's ``model_name``, or "not loaded" while the
          shared agent is None.
        - ``timestamp``: an ISO-8601 string from the server's local clock
          (naive, no UTC offset).
    """
    # Explicit None check rather than truthiness of the agent object.
    model = agent.model_name if agent is not None else "not loaded"
    return {
        "status": "running",
        # Read from the app so it cannot drift from FastAPI(version=...).
        "version": app.version,
        "model": model,
        "timestamp": datetime.now().isoformat(),
    }