Ali2206 committed on
Commit
3cfe99a
·
verified ·
1 Parent(s): 7e095f4

Update src/txagent/txagent.py

Browse files
Files changed (1) hide show
  1. src/txagent/txagent.py +117 -22
src/txagent/txagent.py CHANGED
@@ -1,15 +1,23 @@
1
- # txagent.py - Core TxAgent class (simplified but maintains key functionality)
2
  import os
3
  import logging
4
  import torch
5
- import json
 
6
  from typing import Dict, Optional, List, Union
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
8
  from sentence_transformers import SentenceTransformer
9
  from tooluniverse import ToolUniverse
10
  from .toolrag import ToolRAGModel
11
 
12
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
 
 
 
 
 
 
 
13
  logger = logging.getLogger("TxAgent")
14
 
15
  class TxAgent:
@@ -69,10 +77,14 @@ class TxAgent:
69
 
70
  def init_model(self):
71
  """Initialize all models and components"""
72
- self.load_llm_model()
73
- self.load_rag_model()
74
- self.load_tooluniverse()
75
- logger.info("All models initialized successfully")
 
 
 
 
76
 
77
  def load_llm_model(self):
78
  """Load the main LLM model"""
@@ -93,7 +105,7 @@ class TxAgent:
93
  )
94
  logger.info(f"LLM model loaded on {self.device}")
95
  except Exception as e:
96
- logger.error(f"Failed to load LLM model: {str(e)}")
97
  raise
98
 
99
  def load_rag_model(self):
@@ -103,13 +115,13 @@ class TxAgent:
103
  self.rag_model = ToolRAGModel(self.rag_model_name)
104
  logger.info("RAG model loaded successfully")
105
  except Exception as e:
106
- logger.error(f"Failed to load RAG model: {str(e)}")
107
  raise
108
 
109
  def load_tooluniverse(self):
110
  """Initialize the ToolUniverse"""
111
  try:
112
- logger.info("Loading ToolUniverse")
113
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
114
  self.tooluniverse.load_tools()
115
 
@@ -120,7 +132,7 @@ class TxAgent:
120
 
121
  logger.info(f"ToolUniverse loaded with {len(self.special_tools_name)} special tools")
122
  except Exception as e:
123
- logger.error(f"Failed to load ToolUniverse: {str(e)}")
124
  raise
125
 
126
  def chat(self, message: str, history: Optional[List[Dict]] = None,
@@ -164,7 +176,7 @@ class TxAgent:
164
  return response.strip()
165
 
166
  except Exception as e:
167
- logger.error(f"Chat failed: {str(e)}")
168
  raise RuntimeError(f"Chat failed: {str(e)}")
169
 
170
  def run_multistep_agent(self, message: str, temperature: float = 0.7,
@@ -174,7 +186,9 @@ class TxAgent:
174
  conversation = [{"role": "system", "content": self.prompt_multi_step}]
175
  conversation.append({"role": "user", "content": message})
176
 
177
- for _ in range(max_round):
 
 
178
  # Generate next step
179
  inputs = self.tokenizer.apply_chat_template(
180
  conversation,
@@ -198,24 +212,31 @@ class TxAgent:
198
 
199
  # Check for final answer
200
  if "[FinalAnswer]" in response:
201
- return response.split("[FinalAnswer]")[-1].strip()
 
 
202
 
203
  # Add to conversation
204
  conversation.append({"role": "assistant", "content": response})
 
205
 
206
  # If max rounds reached
207
  if self.force_finish:
 
208
  return self._force_final_answer(conversation, temperature, max_new_tokens)
209
 
 
210
  return "Reasoning rounds exceeded limit without reaching a final answer."
211
 
212
  except Exception as e:
213
- logger.error(f"Multi-step agent failed: {str(e)}")
214
  raise RuntimeError(f"Multi-step agent failed: {str(e)}")
215
 
216
  def _force_final_answer(self, conversation: List[Dict], temperature: float, max_new_tokens: int) -> str:
217
  """Force a final answer when max rounds reached"""
218
  try:
 
 
219
  # Add instruction to provide final answer
220
  conversation.append({
221
  "role": "user",
@@ -244,17 +265,91 @@ class TxAgent:
244
  return response.strip()
245
 
246
  except Exception as e:
247
- logger.error(f"Failed to force final answer: {str(e)}")
248
  return "Failed to generate final answer."
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  def cleanup(self):
251
  """Clean up resources"""
252
- if hasattr(self, 'model'):
253
- del self.model
254
- if hasattr(self, 'rag_model'):
255
- del self.rag_model
256
- torch.cuda.empty_cache()
257
- logger.info("TxAgent resources cleaned up")
 
 
 
 
 
258
 
259
  def __del__(self):
260
  """Destructor to ensure proper cleanup"""
 
 
1
  import os
2
  import logging
3
  import torch
4
+ import pdfplumber
5
+ import pandas as pd
6
  from typing import Dict, Optional, List, Union
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
8
  from sentence_transformers import SentenceTransformer
9
  from tooluniverse import ToolUniverse
10
  from .toolrag import ToolRAGModel
11
 
12
# Logging setup: INFO-and-above records go to both the console and the
# 'txagent_core.log' file, using one shared timestamped format.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(), logging.FileHandler('txagent_core.log')],
)
logger = logging.getLogger("TxAgent")
22
 
23
  class TxAgent:
 
77
 
78
  def init_model(self):
79
  """Initialize all models and components"""
80
+ try:
81
+ self.load_llm_model()
82
+ self.load_rag_model()
83
+ self.load_tooluniverse()
84
+ logger.info("All models initialized successfully")
85
+ except Exception as e:
86
+ logger.error(f"Model initialization failed: {str(e)}", exc_info=True)
87
+ raise
88
 
89
  def load_llm_model(self):
90
  """Load the main LLM model"""
 
105
  )
106
  logger.info(f"LLM model loaded on {self.device}")
107
  except Exception as e:
108
+ logger.error(f"Failed to load LLM model: {str(e)}", exc_info=True)
109
  raise
110
 
111
  def load_rag_model(self):
 
115
  self.rag_model = ToolRAGModel(self.rag_model_name)
116
  logger.info("RAG model loaded successfully")
117
  except Exception as e:
118
+ logger.error(f"Failed to load RAG model: {str(e)}", exc_info=True)
119
  raise
120
 
121
  def load_tooluniverse(self):
122
  """Initialize the ToolUniverse"""
123
  try:
124
+ logger.info("Loading ToolUniverse with files: %s", self.tool_files_dict)
125
  self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
126
  self.tooluniverse.load_tools()
127
 
 
132
 
133
  logger.info(f"ToolUniverse loaded with {len(self.special_tools_name)} special tools")
134
  except Exception as e:
135
+ logger.error(f"Failed to load ToolUniverse: {str(e)}", exc_info=True)
136
  raise
137
 
138
  def chat(self, message: str, history: Optional[List[Dict]] = None,
 
176
  return response.strip()
177
 
178
  except Exception as e:
179
+ logger.error(f"Chat failed: {str(e)}", exc_info=True)
180
  raise RuntimeError(f"Chat failed: {str(e)}")
181
 
182
  def run_multistep_agent(self, message: str, temperature: float = 0.7,
 
186
  conversation = [{"role": "system", "content": self.prompt_multi_step}]
187
  conversation.append({"role": "user", "content": message})
188
 
189
+ for round_num in range(1, max_round + 1):
190
+ logger.info(f"Starting reasoning round {round_num}/{max_round}")
191
+
192
  # Generate next step
193
  inputs = self.tokenizer.apply_chat_template(
194
  conversation,
 
212
 
213
  # Check for final answer
214
  if "[FinalAnswer]" in response:
215
+ final_answer = response.split("[FinalAnswer]")[-1].strip()
216
+ logger.info(f"Final answer found in round {round_num}")
217
+ return final_answer
218
 
219
  # Add to conversation
220
  conversation.append({"role": "assistant", "content": response})
221
+ logger.info(f"Round {round_num} completed without final answer")
222
 
223
  # If max rounds reached
224
  if self.force_finish:
225
+ logger.info("Max rounds reached, forcing final answer")
226
  return self._force_final_answer(conversation, temperature, max_new_tokens)
227
 
228
+ logger.warning("Max rounds reached without final answer")
229
  return "Reasoning rounds exceeded limit without reaching a final answer."
230
 
231
  except Exception as e:
232
+ logger.error(f"Multi-step agent failed: {str(e)}", exc_info=True)
233
  raise RuntimeError(f"Multi-step agent failed: {str(e)}")
234
 
235
  def _force_final_answer(self, conversation: List[Dict], temperature: float, max_new_tokens: int) -> str:
236
  """Force a final answer when max rounds reached"""
237
  try:
238
+ logger.info("Attempting to force final answer")
239
+
240
  # Add instruction to provide final answer
241
  conversation.append({
242
  "role": "user",
 
265
  return response.strip()
266
 
267
  except Exception as e:
268
+ logger.error(f"Failed to force final answer: {str(e)}", exc_info=True)
269
  return "Failed to generate final answer."
270
 
271
def extract_text_from_file(self, file_path: str) -> Optional[str]:
    """Extract text from a PDF, CSV, or Excel file.

    The file type is chosen from the file name's extension, compared
    case-insensitively ('.pdf', '.csv', '.xlsx', '.xls').

    Args:
        file_path: Path of the file to read.

    Returns:
        The extracted text: PDF page texts joined by newlines (pages that
        yield no text are skipped), or ``DataFrame.to_string()`` for CSV and
        Excel input. ``None`` when the extension is not supported.

    Raises:
        RuntimeError: If opening or parsing the file fails; the original
            exception is attached as ``__cause__``.
    """
    try:
        logger.info(f"Extracting text from file: {file_path}")

        # Compare against a lower-cased copy so 'REPORT.PDF' is handled the
        # same way as 'report.pdf'; the real path is still used for I/O.
        lowered_path = file_path.lower()

        if lowered_path.endswith('.pdf'):
            with pdfplumber.open(file_path) as pdf:
                # Extract each page exactly once; text extraction is the
                # expensive step, so do not repeat it just to filter.
                page_texts = [page.extract_text() for page in pdf.pages]
                text = "\n".join(chunk for chunk in page_texts if chunk)
            logger.info(f"Extracted {len(text)} characters from PDF")
            return text

        if lowered_path.endswith('.csv'):
            text = pd.read_csv(file_path).to_string()
            logger.info(f"Extracted {len(text)} characters from CSV")
            return text

        if lowered_path.endswith(('.xlsx', '.xls')):
            text = pd.read_excel(file_path).to_string()
            logger.info(f"Extracted {len(text)} characters from Excel")
            return text

        logger.warning(f"Unsupported file type: {file_path}")
        return None

    except Exception as e:
        logger.error(f"Text extraction failed: {str(e)}", exc_info=True)
        # Chain the cause so callers can still inspect the underlying error.
        raise RuntimeError(f"Text extraction failed: {str(e)}") from e
304
+
305
def analyze_text(self, text: str, max_tokens: int = 1000) -> str:
    """Analyze extracted document text using the LLM.

    Builds a fixed instruction prompt around the first 8000 characters of
    ``text``, tokenizes it, moves the tensors to ``self.device``, and samples
    a completion from ``self.model`` (temperature 0.7).

    Args:
        text: Document text to analyze.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        The decoded first generated sequence, with special tokens removed.

    Raises:
        RuntimeError: If tokenization, generation, or decoding fails; the
            original exception is attached as ``__cause__``.
    """
    try:
        logger.info(f"Analyzing text (first 100 chars): {text[:100]}...")

        # Truncate to avoid token limits. This note must stay a real comment:
        # written inside the prompt literal it would be sent to the model as
        # part of the document.
        document = text[:8000]
        prompt = (
            "Analyze this medical document:\n"
            "1. Diagnostic patterns\n"
            "2. Medication issues\n"
            "3. Recommended follow-ups\n"
            "\n"
            "Document:\n"
            f"{document}\n"
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        generation_config = GenerationConfig(
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id
        )

        outputs = self.model.generate(
            **inputs,
            generation_config=generation_config
        )

        # NOTE(review): the whole first sequence is decoded; for decoder-only
        # models this presumably includes the prompt as well as the
        # completion - confirm whether callers expect that.
        analysis = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info("Text analysis completed successfully")
        return analysis

    except Exception as e:
        logger.error(f"Text analysis failed: {str(e)}", exc_info=True)
        # Chain the cause so the underlying failure stays inspectable.
        raise RuntimeError(f"Analysis failed: {str(e)}") from e
339
+
340
  def cleanup(self):
341
  """Clean up resources"""
342
+ try:
343
+ logger.info("Cleaning up TxAgent resources")
344
+ if hasattr(self, 'model'):
345
+ del self.model
346
+ if hasattr(self, 'rag_model'):
347
+ del self.rag_model
348
+ torch.cuda.empty_cache()
349
+ logger.info("TxAgent resources cleaned up")
350
+ except Exception as e:
351
+ logger.error(f"Cleanup failed: {str(e)}", exc_info=True)
352
+ raise
353
 
354
  def __del__(self):
355
  """Destructor to ensure proper cleanup"""