Ali2206 committed · verified
Commit fd39839 · Parent(s): 7757822

Update src/txagent/txagent.py

Files changed (1):
  1. src/txagent/txagent.py +248 -165

src/txagent/txagent.py CHANGED
@@ -1,178 +1,261 @@
- # app.py - FastAPI application
  import os
- import sys
- import json
- import shutil
- from fastapi import FastAPI, HTTPException, UploadFile, File
- from fastapi.responses import JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- from typing import List, Dict, Optional
  import torch
- from datetime import datetime
- from pydantic import BaseModel
-
- # Configuration
- persistent_dir = "/data/hf_cache"
- model_cache_dir = os.path.join(persistent_dir, "txagent_models")
- tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
- file_cache_dir = os.path.join(persistent_dir, "cache")
- report_dir = os.path.join(persistent_dir, "reports")
-
- # Create directories if they don't exist
- os.makedirs(model_cache_dir, exist_ok=True)
- os.makedirs(tool_cache_dir, exist_ok=True)
- os.makedirs(file_cache_dir, exist_ok=True)
- os.makedirs(report_dir, exist_ok=True)
-
- # Set environment variables
- os.environ["HF_HOME"] = model_cache_dir
- os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
-
- # Set up Python path
- current_dir = os.path.dirname(os.path.abspath(__file__))
- src_path = os.path.abspath(os.path.join(current_dir, "src"))
- sys.path.insert(0, src_path)
-
- # Request models
- class ChatRequest(BaseModel):
-     message: str
-     temperature: float = 0.7
-     max_new_tokens: int = 512
-     history: Optional[List[Dict]] = None

- class MultistepRequest(BaseModel):
-     message: str
-     temperature: float = 0.7
-     max_new_tokens: int = 512
-     max_round: int = 5

- # Initialize FastAPI app
- app = FastAPI(
-     title="TxAgent API",
-     description="API for TxAgent medical document analysis",
-     version="1.0.0"
- )

- # CORS configuration
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )

- # Initialize agent at startup
- agent = None

- @app.on_event("startup")
- async def startup_event():
-     global agent
-     try:
-         agent = init_agent()
-     except Exception as e:
-         raise RuntimeError(f"Failed to initialize agent: {str(e)}")

- def init_agent():
-     """Initialize and return the TxAgent instance"""
-     tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-     if not os.path.exists(tool_path):
-         shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
-
-     agent = TxAgent(
-         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-         tool_files_dict={"new_tool": tool_path},
-         enable_finish=True,
-         enable_rag=False,
-         force_finish=True,
-         enable_checker=True,
-         step_rag_num=4,
-         seed=100
-     )
-     agent.init_model()
-     return agent

- @app.post("/chat")
- async def chat_endpoint(request: ChatRequest):
-     """Handle chat conversations"""
-     try:
-         response = agent.chat(
-             message=request.message,
-             history=request.history,
-             temperature=request.temperature,
-             max_new_tokens=request.max_new_tokens
-         )
-         return JSONResponse({
-             "status": "success",
-             "response": response,
-             "timestamp": datetime.now().isoformat()
-         })
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))

- @app.post("/multistep")
- async def multistep_endpoint(request: MultistepRequest):
-     """Run multi-step reasoning"""
-     try:
-         response = agent.run_multistep_agent(
-             message=request.message,
-             temperature=request.temperature,
-             max_new_tokens=request.max_new_tokens,
-             max_round=request.max_round
-         )
-         return JSONResponse({
-             "status": "success",
-             "response": response,
-             "timestamp": datetime.now().isoformat()
-         })
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))

- @app.post("/analyze")
- async def analyze_document(file: UploadFile = File(...)):
-     """Analyze a medical document"""
-     try:
-         # Save the uploaded file temporarily
-         temp_path = os.path.join(file_cache_dir, file.filename)
-         with open(temp_path, "wb") as f:
-             f.write(await file.read())
-
-         # Process the document
-         text = agent.extract_text_from_file(temp_path)
-         analysis = agent.analyze_text(text)
-
-         # Generate report
-         report_path = os.path.join(report_dir, f"{file.filename}.json")
-         with open(report_path, "w") as f:
-             json.dump({
-                 "filename": file.filename,
-                 "analysis": analysis,
-                 "timestamp": datetime.now().isoformat()
-             }, f)
-
-         # Clean up
-         os.remove(temp_path)
-
-         return JSONResponse({
-             "status": "success",
-             "analysis": analysis,
-             "report_path": report_path,
-             "timestamp": datetime.now().isoformat()
-         })
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))

- @app.get("/status")
- async def service_status():
-     """Check service status"""
-     return {
-         "status": "running",
-         "version": "1.0.0",
-         "model": agent.model_name if agent else "not loaded",
-         "device": str(agent.device) if agent else "unknown"
-     }

- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)
+ # txagent.py - Core TxAgent class (simplified but maintains key functionality)
  import os
+ import logging
  import torch
+ import json
+ from typing import Dict, Optional, List, Union
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+ from sentence_transformers import SentenceTransformer
+ from tooluniverse import ToolUniverse
+ from .toolrag import ToolRAGModel
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger("TxAgent")
+
+ class TxAgent:
+     def __init__(self,
+                  model_name: str,
+                  rag_model_name: str,
+                  tool_files_dict: Optional[Dict] = None,
+                  enable_finish: bool = True,
+                  enable_rag: bool = False,
+                  enable_summary: bool = False,
+                  init_rag_num: int = 0,
+                  step_rag_num: int = 0,
+                  summary_mode: str = 'step',
+                  summary_skip_last_k: int = 0,
+                  summary_context_length: Optional[int] = None,
+                  force_finish: bool = True,
+                  avoid_repeat: bool = True,
+                  seed: Optional[int] = None,
+                  enable_checker: bool = False,
+                  enable_chat: bool = False,
+                  additional_default_tools: Optional[List] = None):
+
+         # Initialization parameters
+         self.model_name = model_name
+         self.rag_model_name = rag_model_name
+         self.tool_files_dict = tool_files_dict or {}
+         self.enable_finish = enable_finish
+         self.enable_rag = enable_rag
+         self.enable_summary = enable_summary
+         self.summary_mode = summary_mode
+         self.summary_skip_last_k = summary_skip_last_k
+         self.summary_context_length = summary_context_length
+         self.init_rag_num = init_rag_num
+         self.step_rag_num = step_rag_num
+         self.force_finish = force_finish
+         self.avoid_repeat = avoid_repeat
+         self.seed = seed
+         self.enable_checker = enable_checker
+         self.enable_chat = enable_chat
+         self.additional_default_tools = additional_default_tools or []
+
+         # Device setup
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Models
+         self.model = None
+         self.tokenizer = None
+         self.rag_model = None
+         self.tooluniverse = None
+
+         # Prompts
+         self.prompt_multi_step = "You are a helpful assistant that solves problems through step-by-step reasoning."
+         self.self_prompt = "Strictly follow the instruction."
+         self.chat_prompt = "You are a helpful assistant for user chat."
+
+         logger.info(f"Initialized TxAgent with model: {model_name} on device: {self.device}")
+
+     def init_model(self):
+         """Initialize all models and components"""
+         self.load_llm_model()
+         self.load_rag_model()
+         self.load_tooluniverse()
+         logger.info("All models initialized successfully")
+
+     def load_llm_model(self):
+         """Load the main LLM model"""
+         try:
+             logger.info(f"Loading LLM model: {self.model_name}")
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 self.model_name,
+                 cache_dir=os.getenv("HF_HOME"),
+                 trust_remote_code=True
+             )
+
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.model_name,
+                 torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
+                 device_map="auto",
+                 cache_dir=os.getenv("HF_HOME"),
+                 trust_remote_code=True
+             )
+             logger.info(f"LLM model loaded on {self.device}")
+         except Exception as e:
+             logger.error(f"Failed to load LLM model: {str(e)}")
+             raise
+
+     def load_rag_model(self):
+         """Load the RAG model"""
+         try:
+             logger.info(f"Loading RAG model: {self.rag_model_name}")
+             self.rag_model = ToolRAGModel(self.rag_model_name)
+             logger.info("RAG model loaded successfully")
+         except Exception as e:
+             logger.error(f"Failed to load RAG model: {str(e)}")
+             raise
+
+     def load_tooluniverse(self):
+         """Initialize the ToolUniverse"""
+         try:
+             logger.info("Loading ToolUniverse")
+             self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
+             self.tooluniverse.load_tools()
+
+             # Prepare special tools
+             special_tools = self.tooluniverse.prepare_tool_prompts(
+                 self.tooluniverse.tool_category_dicts["special_tools"])
+             self.special_tools_name = [tool['name'] for tool in special_tools]
+
+             logger.info(f"ToolUniverse loaded with {len(self.special_tools_name)} special tools")
+         except Exception as e:
+             logger.error(f"Failed to load ToolUniverse: {str(e)}")
+             raise
+
+     def chat(self, message: str, history: Optional[List[Dict]] = None,
+              temperature: float = 0.7, max_new_tokens: int = 512) -> str:
+         """Handle chat conversations"""
+         try:
+             conversation = []
+
+             # Initialize with system prompt
+             conversation.append({"role": "system", "content": self.chat_prompt})
+
+             # Add history if provided
+             if history:
+                 for msg in history:
+                     conversation.append({"role": msg["role"], "content": msg["content"]})
+
+             # Add current message
+             conversation.append({"role": "user", "content": message})
+
+             # Generate response
+             inputs = self.tokenizer.apply_chat_template(
+                 conversation,
+                 add_generation_prompt=True,
+                 return_tensors="pt"
+             ).to(self.device)
+
+             generation_config = GenerationConfig(
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 do_sample=True,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+             outputs = self.model.generate(
+                 inputs,
+                 generation_config=generation_config
+             )
+
+             # Decode and clean up response
+             response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
+             return response.strip()
+
+         except Exception as e:
+             logger.error(f"Chat failed: {str(e)}")
+             raise RuntimeError(f"Chat failed: {str(e)}")
+
+     def run_multistep_agent(self, message: str, temperature: float = 0.7,
+                             max_new_tokens: int = 512, max_round: int = 5) -> str:
+         """Run multi-step reasoning agent"""
+         try:
+             conversation = [{"role": "system", "content": self.prompt_multi_step}]
+             conversation.append({"role": "user", "content": message})
+
+             for _ in range(max_round):
+                 # Generate next step
+                 inputs = self.tokenizer.apply_chat_template(
+                     conversation,
+                     add_generation_prompt=True,
+                     return_tensors="pt"
+                 ).to(self.device)
+
+                 generation_config = GenerationConfig(
+                     max_new_tokens=max_new_tokens,
+                     temperature=temperature,
+                     do_sample=True,
+                     pad_token_id=self.tokenizer.eos_token_id
+                 )
+
+                 outputs = self.model.generate(
+                     inputs,
+                     generation_config=generation_config
+                 )
+
+                 response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
+
+                 # Check for final answer
+                 if "[FinalAnswer]" in response:
+                     return response.split("[FinalAnswer]")[-1].strip()
+
+                 # Add to conversation
+                 conversation.append({"role": "assistant", "content": response})
+
+             # If max rounds reached
+             if self.force_finish:
+                 return self._force_final_answer(conversation, temperature, max_new_tokens)
+
+             return "Reasoning rounds exceeded limit without reaching a final answer."
+
+         except Exception as e:
+             logger.error(f"Multi-step agent failed: {str(e)}")
+             raise RuntimeError(f"Multi-step agent failed: {str(e)}")
+
+     def _force_final_answer(self, conversation: List[Dict], temperature: float, max_new_tokens: int) -> str:
+         """Force a final answer when max rounds reached"""
+         try:
+             # Add instruction to provide final answer
+             conversation.append({
+                 "role": "user",
+                 "content": "Provide your final answer now based on all previous reasoning."
+             })
+
+             inputs = self.tokenizer.apply_chat_template(
+                 conversation,
+                 add_generation_prompt=True,
+                 return_tensors="pt"
+             ).to(self.device)
+
+             generation_config = GenerationConfig(
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 do_sample=True,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+             outputs = self.model.generate(
+                 inputs,
+                 generation_config=generation_config
+             )
+
+             response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
+             return response.strip()
+
+         except Exception as e:
+             logger.error(f"Failed to force final answer: {str(e)}")
+             return "Failed to generate final answer."
+
+     def cleanup(self):
+         """Clean up resources"""
+         if hasattr(self, 'model'):
+             del self.model
+         if hasattr(self, 'rag_model'):
+             del self.rag_model
+         torch.cuda.empty_cache()
+         logger.info("TxAgent resources cleaned up")
+
+     def __del__(self):
+         """Destructor to ensure proper cleanup"""
+         self.cleanup()
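
For downstream callers (the FastAPI app that previously lived in this file now imports the class instead), a minimal usage sketch follows, assuming the package is importable as txagent. The constructor arguments are copied from the removed init_agent(); the tool-file path and the prompts are illustrative placeholders, not part of this commit.

# Usage sketch (illustrative, not part of this commit). Constructor arguments
# mirror the removed init_agent(); the tool-file path and prompts are placeholders.
from txagent.txagent import TxAgent

agent = TxAgent(
    model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
    rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    tool_files_dict={"new_tool": "/data/hf_cache/tool_cache/new_tool.json"},
    enable_finish=True,
    enable_rag=False,
    force_finish=True,
    enable_checker=True,
    step_rag_num=4,
    seed=100,
)
agent.init_model()  # loads the LLM, the ToolRAG model, and the ToolUniverse

# Single-turn chat
reply = agent.chat("Summarize the key findings in this discharge note.")

# Multi-step reasoning; returns early if the model emits [FinalAnswer],
# otherwise _force_final_answer() runs because force_finish=True
answer = agent.run_multistep_agent("Which follow-up tests are recommended?", max_round=5)

agent.cleanup()  # frees model memory and empties the CUDA cache

Because init_model() downloads and loads both checkpoints, a caller such as the old @app.on_event("startup") hook should run it once per process and reuse the agent across requests.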