Ali2206 committed on
Commit adac5ab · verified · 1 Parent(s): bdcc052

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +21 -223
src/txagent/txagent.py CHANGED
@@ -1,22 +1,14 @@
 import os
 import logging
 import torch
-import pdfplumber
-import pandas as pd
 from typing import Dict, Optional, List, Union
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 from sentence_transformers import SentenceTransformer
-from tooluniverse import ToolUniverse
-from .toolrag import ToolRAGModel
 
-# Configure logging
+# Configure logging for Hugging Face Spaces
 logging.basicConfig(
     level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-    handlers=[
-        logging.StreamHandler(),
-        logging.FileHandler('txagent_core.log')
-    ]
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 )
 logger = logging.getLogger("TxAgent")
 
@@ -27,18 +19,10 @@ class TxAgent:
                  tool_files_dict: Optional[Dict] = None,
                  enable_finish: bool = True,
                  enable_rag: bool = False,
-                 enable_summary: bool = False,
-                 init_rag_num: int = 0,
-                 step_rag_num: int = 0,
-                 summary_mode: str = 'step',
-                 summary_skip_last_k: int = 0,
-                 summary_context_length: Optional[int] = None,
                  force_finish: bool = True,
-                 avoid_repeat: bool = True,
-                 seed: Optional[int] = None,
-                 enable_checker: bool = False,
-                 enable_chat: bool = False,
-                 additional_default_tools: Optional[List] = None):
+                 enable_checker: bool = True,
+                 step_rag_num: int = 4,
+                 seed: Optional[int] = None):
 
         # Initialization parameters
         self.model_name = model_name
@@ -46,18 +30,10 @@ class TxAgent:
         self.tool_files_dict = tool_files_dict or {}
         self.enable_finish = enable_finish
         self.enable_rag = enable_rag
-        self.enable_summary = enable_summary
-        self.summary_mode = summary_mode
-        self.summary_skip_last_k = summary_skip_last_k
-        self.summary_context_length = summary_context_length
-        self.init_rag_num = init_rag_num
-        self.step_rag_num = step_rag_num
         self.force_finish = force_finish
-        self.avoid_repeat = avoid_repeat
-        self.seed = seed
         self.enable_checker = enable_checker
-        self.enable_chat = enable_chat
-        self.additional_default_tools = additional_default_tools or []
+        self.step_rag_num = step_rag_num
+        self.seed = seed
 
         # Device setup
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -66,24 +42,21 @@ class TxAgent:
         self.model = None
         self.tokenizer = None
         self.rag_model = None
-        self.tooluniverse = None
 
         # Prompts
-        self.prompt_multi_step = "You are a helpful assistant that solves problems through step-by-step reasoning."
-        self.self_prompt = "Strictly follow the instruction."
         self.chat_prompt = "You are a helpful assistant for user chat."
 
-        logger.info(f"Initialized TxAgent with model: {model_name} on device: {self.device}")
+        logger.info(f"Initialized TxAgent with model: {model_name}")
 
     def init_model(self):
         """Initialize all models and components"""
         try:
             self.load_llm_model()
-            self.load_rag_model()
-            self.load_tooluniverse()
-            logger.info("All models initialized successfully")
+            if self.enable_rag:
+                self.load_rag_model()
+            logger.info("Models initialized successfully")
         except Exception as e:
-            logger.error(f"Model initialization failed: {str(e)}", exc_info=True)
+            logger.error(f"Model initialization failed: {str(e)}")
             raise
 
     def load_llm_model(self):
@@ -92,7 +65,6 @@ class TxAgent:
             logger.info(f"Loading LLM model: {self.model_name}")
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.model_name,
-                cache_dir=os.getenv("HF_HOME"),
                 trust_remote_code=True
             )
 
@@ -100,39 +72,24 @@ class TxAgent:
                 self.model_name,
                 torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                 device_map="auto",
-                cache_dir=os.getenv("HF_HOME"),
                 trust_remote_code=True
            )
             logger.info(f"LLM model loaded on {self.device}")
         except Exception as e:
-            logger.error(f"Failed to load LLM model: {str(e)}", exc_info=True)
+            logger.error(f"Failed to load LLM model: {str(e)}")
             raise
 
     def load_rag_model(self):
         """Load the RAG model"""
         try:
             logger.info(f"Loading RAG model: {self.rag_model_name}")
-            self.rag_model = ToolRAGModel(self.rag_model_name)
+            self.rag_model = SentenceTransformer(
+                self.rag_model_name,
+                device=str(self.device)
+            )
             logger.info("RAG model loaded successfully")
         except Exception as e:
-            logger.error(f"Failed to load RAG model: {str(e)}", exc_info=True)
-            raise
-
-    def load_tooluniverse(self):
-        """Initialize the ToolUniverse"""
-        try:
-            logger.info("Loading ToolUniverse with files: %s", self.tool_files_dict)
-            self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
-            self.tooluniverse.load_tools()
-
-            # Prepare special tools
-            special_tools = self.tooluniverse.prepare_tool_prompts(
-                self.tooluniverse.tool_category_dicts["special_tools"])
-            self.special_tools_name = [tool['name'] for tool in special_tools]
-
-            logger.info(f"ToolUniverse loaded with {len(self.special_tools_name)} special tools")
-        except Exception as e:
-            logger.error(f"Failed to load ToolUniverse: {str(e)}", exc_info=True)
+            logger.error(f"Failed to load RAG model: {str(e)}")
             raise
 
     def chat(self, message: str, history: Optional[List[Dict]] = None,
@@ -176,179 +133,20 @@ class TxAgent:
             return response.strip()
 
         except Exception as e:
-            logger.error(f"Chat failed: {str(e)}", exc_info=True)
+            logger.error(f"Chat failed: {str(e)}")
             raise RuntimeError(f"Chat failed: {str(e)}")
 
-    def run_multistep_agent(self, message: str, temperature: float = 0.7,
-                            max_new_tokens: int = 512, max_round: int = 5) -> str:
-        """Run multi-step reasoning agent"""
-        try:
-            conversation = [{"role": "system", "content": self.prompt_multi_step}]
-            conversation.append({"role": "user", "content": message})
-
-            for round_num in range(1, max_round + 1):
-                logger.info(f"Starting reasoning round {round_num}/{max_round}")
-
-                # Generate next step
-                inputs = self.tokenizer.apply_chat_template(
-                    conversation,
-                    add_generation_prompt=True,
-                    return_tensors="pt"
-                ).to(self.device)
-
-                generation_config = GenerationConfig(
-                    max_new_tokens=max_new_tokens,
-                    temperature=temperature,
-                    do_sample=True,
-                    pad_token_id=self.tokenizer.eos_token_id
-                )
-
-                outputs = self.model.generate(
-                    inputs,
-                    generation_config=generation_config
-                )
-
-                response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
-
-                # Check for final answer
-                if "[FinalAnswer]" in response:
-                    final_answer = response.split("[FinalAnswer]")[-1].strip()
-                    logger.info(f"Final answer found in round {round_num}")
-                    return final_answer
-
-                # Add to conversation
-                conversation.append({"role": "assistant", "content": response})
-                logger.info(f"Round {round_num} completed without final answer")
-
-            # If max rounds reached
-            if self.force_finish:
-                logger.info("Max rounds reached, forcing final answer")
-                return self._force_final_answer(conversation, temperature, max_new_tokens)
-
-            logger.warning("Max rounds reached without final answer")
-            return "Reasoning rounds exceeded limit without reaching a final answer."
-
-        except Exception as e:
-            logger.error(f"Multi-step agent failed: {str(e)}", exc_info=True)
-            raise RuntimeError(f"Multi-step agent failed: {str(e)}")
-
-    def _force_final_answer(self, conversation: List[Dict], temperature: float, max_new_tokens: int) -> str:
-        """Force a final answer when max rounds reached"""
-        try:
-            logger.info("Attempting to force final answer")
-
-            # Add instruction to provide final answer
-            conversation.append({
-                "role": "user",
-                "content": "Provide your final answer now based on all previous reasoning."
-            })
-
-            inputs = self.tokenizer.apply_chat_template(
-                conversation,
-                add_generation_prompt=True,
-                return_tensors="pt"
-            ).to(self.device)
-
-            generation_config = GenerationConfig(
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                do_sample=True,
-                pad_token_id=self.tokenizer.eos_token_id
-            )
-
-            outputs = self.model.generate(
-                inputs,
-                generation_config=generation_config
-            )
-
-            response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
-            return response.strip()
-
-        except Exception as e:
-            logger.error(f"Failed to force final answer: {str(e)}", exc_info=True)
-            return "Failed to generate final answer."
-
-    def extract_text_from_file(self, file_path: str) -> Optional[str]:
-        """Extract text from PDF, CSV, or Excel files"""
-        try:
-            logger.info(f"Extracting text from file: {file_path}")
-
-            if file_path.endswith('.pdf'):
-                with pdfplumber.open(file_path) as pdf:
-                    text = "\n".join(
-                        page.extract_text()
-                        for page in pdf.pages
-                        if page.extract_text()
-                    )
-                logger.info(f"Extracted {len(text)} characters from PDF")
-                return text
-
-            elif file_path.endswith('.csv'):
-                df = pd.read_csv(file_path)
-                text = df.to_string()
-                logger.info(f"Extracted {len(text)} characters from CSV")
-                return text
-
-            elif file_path.endswith(('.xlsx', '.xls')):
-                df = pd.read_excel(file_path)
-                text = df.to_string()
-                logger.info(f"Extracted {len(text)} characters from Excel")
-                return text
-
-            logger.warning(f"Unsupported file type: {file_path}")
-            return None
-
-        except Exception as e:
-            logger.error(f"Text extraction failed: {str(e)}", exc_info=True)
-            raise RuntimeError(f"Text extraction failed: {str(e)}")
-
-    def analyze_text(self, text: str, max_tokens: int = 1000) -> str:
-        """Analyze extracted text using the LLM"""
-        try:
-            logger.info(f"Analyzing text (first 100 chars): {text[:100]}...")
-
-            prompt = f"""Analyze this medical document:
-1. Diagnostic patterns
-2. Medication issues
-3. Recommended follow-ups
-
-Document:
-{text[:8000]} # Truncate to avoid token limits
-"""
-            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-
-            generation_config = GenerationConfig(
-                max_new_tokens=max_tokens,
-                temperature=0.7,
-                do_sample=True,
-                pad_token_id=self.tokenizer.eos_token_id
-            )
-
-            outputs = self.model.generate(
-                **inputs,
-                generation_config=generation_config
-            )
-
-            analysis = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            logger.info("Text analysis completed successfully")
-            return analysis
-
-        except Exception as e:
-            logger.error(f"Text analysis failed: {str(e)}", exc_info=True)
-            raise RuntimeError(f"Analysis failed: {str(e)}")
-
     def cleanup(self):
         """Clean up resources"""
         try:
-            logger.info("Cleaning up TxAgent resources")
             if hasattr(self, 'model'):
                 del self.model
             if hasattr(self, 'rag_model'):
                 del self.rag_model
             torch.cuda.empty_cache()
-            logger.info("TxAgent resources cleaned up")
+            logger.info("Resources cleaned up")
         except Exception as e:
-            logger.error(f"Cleanup failed: {str(e)}", exc_info=True)
+            logger.error(f"Cleanup failed: {str(e)}")
             raise
 
     def __del__(self):
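
For orientation, here is a minimal usage sketch of the class as it stands after this commit. This is an assumption-laden sketch, not code from the repository: the diff does not show the full __init__ or chat() signatures, so the model_name/rag_model_name keyword parameters and the single-argument chat() call are inferred from the visible self.* assignments, and both model ids are placeholders.

# Hypothetical usage sketch. Keyword parameters are inferred from the
# self.* assignments visible in the diff; model ids are placeholders,
# not values taken from this repository.
from txagent.txagent import TxAgent

agent = TxAgent(
    model_name="your-org/your-llm",                           # placeholder LLM id
    rag_model_name="sentence-transformers/all-MiniLM-L6-v2",  # placeholder embedder id
    enable_rag=True,     # after this commit, load_rag_model() runs only when True
    step_rag_num=4,      # new default from the updated signature
    seed=42,
)
agent.init_model()       # loads the LLM, plus the SentenceTransformer if enable_rag is set
reply = agent.chat("List the major drug interactions of warfarin.")
print(reply)
agent.cleanup()          # deletes the models and empties the CUDA cache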