Ali2206 committed on
Commit
61c414a
·
verified ·
1 Parent(s): 6b4270f

Update src/txagent/txagent.py

Browse files
Files changed (1) hide show
  1. src/txagent/txagent.py +173 -158
src/txagent/txagent.py CHANGED
@@ -1,163 +1,178 @@
 
1
  import os
2
- import logging
 
 
 
 
 
 
3
  import torch
4
- import pdfplumber
5
- import pandas as pd
6
- from typing import Dict, Optional, Union
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
- from sentence_transformers import SentenceTransformer
9
-
10
# Configure process-wide logging once at import time and expose the
# module-level logger used throughout this file.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger("TxAgent")
12
-
13
class TxAgent:
    """Document-analysis agent built on ``transformers`` (no vLLM dependency).

    The agent loads a causal language model plus a sentence-embedding model,
    extracts plain text from PDF/CSV/Excel files, and asks the language model
    to analyze that text.
    """

    # Maximum number of document characters embedded in the analysis prompt.
    MAX_DOCUMENT_CHARS = 8000

    def __init__(self,
                 model_name: str,
                 rag_model_name: str,
                 tool_files_dict: Optional[Dict] = None,
                 force_finish: bool = True,
                 enable_checker: bool = True,
                 step_rag_num: int = 4,
                 seed: Optional[int] = None):
        """Store the configuration; models are loaded later by ``init_model``.

        Args:
            model_name: Identifier passed to ``AutoModelForCausalLM`` /
                ``AutoTokenizer``.
            rag_model_name: Identifier passed to ``SentenceTransformer``.
            tool_files_dict: Optional mapping stored on the instance
                (defaults to an empty dict).
            force_finish, enable_checker, step_rag_num, seed: Stored on the
                instance; not otherwise used by the methods of this class.
        """
        self.model_name = model_name
        self.rag_model_name = rag_model_name
        self.tool_files_dict = tool_files_dict or {}
        self.force_finish = force_finish
        self.enable_checker = enable_checker
        self.step_rag_num = step_rag_num
        self.seed = seed

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.tokenizer = None
        self.rag_model = None

        logger.info(f"Initialized TxAgent with model: {model_name} on device: {self.device}")

    def init_model(self):
        """Load the language model and then the RAG embedding model."""
        self.load_llm_model()
        self.load_rag_model()
        logger.info("Model initialization complete")

    def load_llm_model(self):
        """Load the tokenizer and causal LM named by ``self.model_name``.

        Raises:
            RuntimeError: If either load fails (the original error is chained).
        """
        try:
            logger.info(f"Loading LLM model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=os.getenv("HF_HOME")
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                # Half precision only on GPU; CPU stays in float32.
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
                cache_dir=os.getenv("HF_HOME")
            )
            logger.info(f"LLM model loaded on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load LLM model: {str(e)}")
            raise RuntimeError(f"Failed to load LLM model: {str(e)}") from e

    def load_rag_model(self):
        """Load the sentence-embedding model named by ``self.rag_model_name``.

        Raises:
            RuntimeError: If loading fails (the original error is chained).
        """
        try:
            logger.info(f"Loading RAG model: {self.rag_model_name}")
            self.rag_model = SentenceTransformer(
                self.rag_model_name,
                device=str(self.device)
            )
            logger.info("RAG model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load RAG model: {str(e)}")
            raise RuntimeError(f"Failed to load RAG model: {str(e)}") from e

    def process_document(self, file_path: str) -> Dict[str, Union[str, Dict]]:
        """Extract text from ``file_path`` and analyze it.

        Never raises: every failure is reported in the returned dict.

        Returns:
            ``{"status": "success", "analysis": ..., "model": ...}`` on
            success, otherwise ``{"status": "error", "message": ...,
            "model": ...}``.
        """
        try:
            text = self.extract_text_from_file(file_path)
            if not text:
                # Covers both unsupported file types (None) and empty documents ("").
                return {
                    "status": "error",
                    "message": "Failed to extract text",
                    "model": self.model_name
                }

            analysis = self.analyze_text(text)

            return {
                "status": "success",
                "analysis": analysis,
                "model": self.model_name
            }

        except Exception as e:
            logger.error(f"Document processing failed: {str(e)}")
            return {
                "status": "error",
                "message": str(e),
                "model": self.model_name
            }

    def extract_text_from_file(self, file_path: str) -> Optional[str]:
        """Extract text from a PDF, CSV, or Excel file.

        The extension check is case-insensitive, so ``report.PDF`` is handled
        the same way as ``report.pdf``.

        Returns:
            The extracted text, or ``None`` when the extension is unsupported.

        Raises:
            RuntimeError: If reading or parsing the file fails.
        """
        try:
            lowered = file_path.lower()

            if lowered.endswith('.pdf'):
                with pdfplumber.open(file_path) as pdf:
                    # Extract each page once; skip pages that yield no text.
                    page_texts = (page.extract_text() for page in pdf.pages)
                    return "\n".join(chunk for chunk in page_texts if chunk)

            if lowered.endswith('.csv'):
                return pd.read_csv(file_path).to_string()

            if lowered.endswith(('.xlsx', '.xls')):
                return pd.read_excel(file_path).to_string()

            logger.warning(f"Unsupported file type: {file_path}")
            return None

        except Exception as e:
            logger.error(f"Text extraction failed: {str(e)}")
            raise RuntimeError(f"Text extraction failed: {str(e)}") from e

    def analyze_text(self, text: str, max_tokens: int = 1000) -> str:
        """Ask the language model to analyze ``text``.

        Only the first ``MAX_DOCUMENT_CHARS`` characters of ``text`` are put
        into the prompt to stay within the model's context budget.

        Args:
            text: Document text to analyze.
            max_tokens: Upper bound on newly generated tokens.

        Returns:
            The generated analysis only; the prompt is not echoed back.

        Raises:
            RuntimeError: If the model is not loaded or generation fails.
        """
        if getattr(self, "model", None) is None or getattr(self, "tokenizer", None) is None:
            raise RuntimeError("Analysis failed: LLM model is not loaded; call init_model() first")

        try:
            document = text[:self.MAX_DOCUMENT_CHARS]
            # Built from plain lines so no stray comment text leaks into the prompt.
            prompt = (
                "Analyze this medical document:\n"
                "1. Diagnostic patterns\n"
                "2. Medication issues\n"
                "3. Recommended follow-ups\n"
                "\n"
                "Document:\n"
                f"{document}\n"
            )
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                pad_token_id=self.tokenizer.eos_token_id
            )
            # A causal LM returns prompt tokens followed by the continuation;
            # drop the prompt so callers receive only the analysis.
            prompt_length = inputs["input_ids"].shape[1]
            return self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        except Exception as e:
            logger.error(f"Text analysis failed: {str(e)}")
            raise RuntimeError(f"Analysis failed: {str(e)}") from e

    def cleanup(self):
        """Drop model references and release cached GPU memory.

        Attributes are reset to ``None`` (not deleted) so the instance stays
        inspectable and the method is safe to call repeatedly.
        """
        self.model = None
        self.tokenizer = None
        self.rag_model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("TxAgent resources cleaned up")

    def __del__(self):
        """Best-effort cleanup when the instance is garbage-collected."""
        try:
            self.cleanup()
        except Exception:
            # A finalizer must not raise; module globals may already be torn
            # down when it runs during interpreter shutdown.
            pass
 
1
+ # app.py - FastAPI application
2
  import os
3
+ import sys
4
+ import json
5
+ import shutil
6
+ from fastapi import FastAPI, HTTPException, UploadFile, File
7
+ from fastapi.responses import JSONResponse
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from typing import List, Dict, Optional
10
  import torch
11
+ from datetime import datetime
12
+ from pydantic import BaseModel
13
+
14
# Configuration: every cache and report directory lives under one root.
persistent_dir = "/data/hf_cache"
model_cache_dir, tool_cache_dir, file_cache_dir, report_dir = (
    os.path.join(persistent_dir, leaf)
    for leaf in ("txagent_models", "tool_cache", "cache", "reports")
)

# Create directories if they don't exist
for _directory in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(_directory, exist_ok=True)

# Set environment variables (both point at the model cache directory)
for _env_name in ("HF_HOME", "TRANSFORMERS_CACHE"):
    os.environ[_env_name] = model_cache_dir

# Set up Python path: put the ``src`` directory next to this file first
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
35
+
36
+ # Request models
37
class ChatRequest(BaseModel):
    """JSON body accepted by the ``/chat`` endpoint."""

    # Text forwarded to the agent as the chat message.
    message: str
    # Generation settings forwarded unchanged to the agent.
    temperature: float = 0.7
    max_new_tokens: int = 512
    # Optional prior conversation; ``None`` when the client sends none.
    history: Optional[List[Dict]] = None
42
+
43
class MultistepRequest(BaseModel):
    """JSON body accepted by the ``/multistep`` endpoint."""

    # Text forwarded to the agent as the task message.
    message: str
    # Generation settings forwarded unchanged to the agent.
    temperature: float = 0.7
    max_new_tokens: int = 512
    # Forwarded to the agent as ``max_round``.
    max_round: int = 5
48
+
49
# Initialize FastAPI app
app = FastAPI(
    title="TxAgent API",
    description="API for TxAgent medical document analysis",
    version="1.0.0"
)

# CORS configuration
# A wildcard origin must not be combined with credentials: browsers reject
# ``Access-Control-Allow-Origin: *`` on credentialed requests, and reflecting
# arbitrary origins instead would let any site act with a user's cookies.
# To enable credentials, replace "*" with an explicit list of trusted origins
# and set ``allow_credentials=True``.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
64
+
65
# Initialize agent at startup; stays ``None`` until ``startup_event`` succeeds.
agent = None


@app.on_event("startup")
async def startup_event():
    """Create the shared agent once when the application starts.

    Raises:
        RuntimeError: If agent construction or model loading fails; the
            original exception is chained as the cause.
    """
    global agent
    try:
        agent = init_agent()
    except Exception as e:
        # Chain explicitly so the root cause is reported as the direct cause.
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e
75
+
76
def init_agent():
    """Initialize and return the TxAgent instance.

    Copies ``data/new_tool.json`` (resolved against the current working
    directory) into the tool cache on first use, builds the agent, and loads
    its models via ``init_model()``.
    """
    # This module never imports ``TxAgent`` at top level, so referencing it
    # directly fails with NameError. Use the name if it is already defined,
    # otherwise import it.
    # NOTE(review): the module path is assumed from the repository layout
    # (src/txagent/txagent.py with ``src`` on sys.path) — confirm.
    try:
        agent_cls = TxAgent
    except NameError:
        from txagent.txagent import TxAgent as agent_cls

    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(tool_path):
        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)

    agent = agent_cls(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": tool_path},
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100
    )
    agent.init_model()
    return agent
95
+
96
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Forward a chat message to the agent and return its reply.

    Returns:
        JSON with ``status``, ``response`` and an ISO-format ``timestamp``.

    Raises:
        HTTPException: 503 when the agent is not initialized; 500 when the
            agent call fails (detail carries the error message).
    """
    # Checked outside the ``try`` so the 503 is not rewritten into a 500.
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent is not initialized")

    try:
        response = agent.chat(
            message=request.message,
            history=request.history,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens
        )
        return JSONResponse({
            "status": "success",
            "response": response,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
113
+
114
@app.post("/multistep")
async def multistep_endpoint(request: MultistepRequest):
    """Run the agent's multi-step reasoning on a message and return the result.

    Returns:
        JSON with ``status``, ``response`` and an ISO-format ``timestamp``.

    Raises:
        HTTPException: 503 when the agent is not initialized; 500 when the
            agent call fails (detail carries the error message).
    """
    # Checked outside the ``try`` so the 503 is not rewritten into a 500.
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent is not initialized")

    try:
        response = agent.run_multistep_agent(
            message=request.message,
            temperature=request.temperature,
            max_new_tokens=request.max_new_tokens,
            max_round=request.max_round
        )
        return JSONResponse({
            "status": "success",
            "response": response,
            "timestamp": datetime.now().isoformat()
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
131
+
132
@app.post("/analyze")
async def analyze_document(file: UploadFile = File(...)):
    """Analyze an uploaded medical document and persist a JSON report.

    The upload is written to a uniquely named temporary file, its text is
    extracted and analyzed by the agent, and the result is saved under
    ``report_dir`` as ``<filename>.json``. The temporary file is removed
    whether or not processing succeeds.

    Returns:
        JSON with ``status``, ``analysis``, ``report_path`` and ``timestamp``.

    Raises:
        HTTPException: 503 when the agent is not initialized; 400 for a
            missing/invalid filename or when no text can be extracted;
            500 for any other failure.
    """
    import uuid  # only this endpoint needs it

    if agent is None:
        raise HTTPException(status_code=503, detail="Agent is not initialized")

    # The client controls ``file.filename``. Keep only its final path
    # component so names like "../../etc/passwd" cannot escape the cache or
    # report directories.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if not safe_name or safe_name in (".", ".."):
        raise HTTPException(status_code=400, detail="A valid filename is required")

    # A unique prefix keeps concurrent uploads with the same name apart; the
    # original name stays as the suffix so extension-based detection works.
    temp_path = os.path.join(file_cache_dir, f"{uuid.uuid4().hex}_{safe_name}")

    try:
        # Save the uploaded file temporarily
        with open(temp_path, "wb") as f:
            f.write(await file.read())

        # Process the document
        text = agent.extract_text_from_file(temp_path)
        if not text:
            raise HTTPException(
                status_code=400,
                detail="Could not extract text from the uploaded file"
            )
        analysis = agent.analyze_text(text)

        # Generate report
        report_path = os.path.join(report_dir, f"{safe_name}.json")
        with open(report_path, "w") as f:
            json.dump({
                "filename": file.filename,
                "analysis": analysis,
                "timestamp": datetime.now().isoformat()
            }, f)

        return JSONResponse({
            "status": "success",
            "analysis": analysis,
            "report_path": report_path,
            "timestamp": datetime.now().isoformat()
        })
    except HTTPException:
        # Let deliberate client errors through instead of rewriting them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up the temporary upload on success and on failure alike.
        if os.path.exists(temp_path):
            os.remove(temp_path)
165
+
166
@app.get("/status")
async def service_status():
    """Report that the service is up, plus the agent's model and device if set."""
    if agent:
        model_label, device_label = agent.model_name, str(agent.device)
    else:
        # No agent yet: report placeholders rather than failing.
        model_label, device_label = "not loaded", "unknown"

    return {
        "status": "running",
        "version": "1.0.0",
        "model": model_label,
        "device": device_label,
    }
175
+
176
if __name__ == "__main__":
    # Serve the API on all interfaces, port 8000, when run as a script.
    from uvicorn import run as _serve

    _serve(app, host="0.0.0.0", port=8000)