Ali2206 committed on
Commit
698378b
·
verified ·
1 Parent(s): cf95a11

Update src/txagent/txagent.py

Browse files
Files changed (1) hide show
  1. src/txagent/txagent.py +121 -38
src/txagent/txagent.py CHANGED
@@ -1,8 +1,9 @@
1
  import os
2
- import json
3
  import logging
4
  import torch
5
- from typing import List, Dict, Optional, Union
 
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from sentence_transformers import SentenceTransformer
8
 
@@ -15,56 +16,61 @@ class TxAgent:
15
  model_name: str,
16
  rag_model_name: str,
17
  tool_files_dict: Optional[Dict] = None,
18
- use_vllm: bool = False,
19
  force_finish: bool = True,
20
  enable_checker: bool = True,
21
  step_rag_num: int = 4,
22
  seed: Optional[int] = None):
 
 
23
 
 
 
 
 
 
 
 
 
 
24
  self.model_name = model_name
25
  self.rag_model_name = rag_model_name
26
  self.tool_files_dict = tool_files_dict or {}
27
- self.use_vllm = use_vllm
28
  self.force_finish = force_finish
29
  self.enable_checker = enable_checker
30
  self.step_rag_num = step_rag_num
31
  self.seed = seed
32
 
 
33
  self.model = None
34
  self.tokenizer = None
35
  self.rag_model = None
36
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
 
38
- logger.info(f"Initializing TxAgent with model: {model_name} on device: {self.device}")
39
 
40
  def init_model(self):
41
  """Initialize both the main model and RAG model."""
42
- self.load_models()
43
  self.load_rag_model()
44
  logger.info("Model initialization complete")
45
 
46
- def load_models(self):
47
  """Load the main LLM model."""
48
  try:
49
- logger.info(f"Loading model: {self.model_name}")
50
-
51
  self.tokenizer = AutoTokenizer.from_pretrained(
52
  self.model_name,
53
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
54
  )
55
-
56
  self.model = AutoModelForCausalLM.from_pretrained(
57
  self.model_name,
58
- torch_dtype=torch.float16,
59
  device_map="auto",
60
- cache_dir=os.environ.get("TRANSFORMERS_CACHE")
61
  )
62
-
63
- logger.info(f"Successfully loaded model on {self.device}")
64
-
65
  except Exception as e:
66
- logger.error(f"Failed to load model: {str(e)}")
67
- raise RuntimeError(f"Failed to load model: {str(e)}")
68
 
69
  def load_rag_model(self):
70
  """Load the RAG model."""
@@ -79,37 +85,114 @@ class TxAgent:
79
  logger.error(f"Failed to load RAG model: {str(e)}")
80
  raise RuntimeError(f"Failed to load RAG model: {str(e)}")
81
 
82
- def process_document(self, file_path: str) -> Dict:
83
- """Process a document and return analysis results."""
84
- try:
85
- # Extract text (implement your extraction logic)
86
- text = self.extract_text(file_path)
 
87
 
88
- # Process with LLM (implement your processing logic)
89
- result = self.analyze_text(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  return {
92
  "status": "success",
93
- "analysis": result,
94
  "model": self.model_name
95
  }
96
 
97
  except Exception as e:
98
- logger.error(f"Document processing failed: {str(e)}")
99
- raise RuntimeError(f"Document processing failed: {str(e)}")
 
 
 
 
100
 
101
- def extract_text(self, file_path: str) -> str:
102
- """Extract text from various file formats."""
103
- # Implement your text extraction logic here
104
- return "Sample extracted text"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
- def analyze_text(self, text: str) -> str:
107
- """Analyze extracted text using the LLM."""
108
- # Implement your text analysis logic here
109
- return "Sample analysis result"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  def cleanup(self):
112
- """Clean up resources."""
113
  if hasattr(self, 'model'):
114
  del self.model
115
  if hasattr(self, 'rag_model'):
 
1
  import os
 
2
  import logging
3
  import torch
4
+ import pdfplumber
5
+ import pandas as pd
6
+ from typing import Dict, Optional, Union
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from sentence_transformers import SentenceTransformer
9
 
 
16
  model_name: str,
17
  rag_model_name: str,
18
  tool_files_dict: Optional[Dict] = None,
 
19
  force_finish: bool = True,
20
  enable_checker: bool = True,
21
  step_rag_num: int = 4,
22
  seed: Optional[int] = None):
23
+ """
24
+ Initialize the TxAgent with specified configuration.
25
 
26
+ Args:
27
+ model_name: Name/path of the main LLM model
28
+ rag_model_name: Name/path of the RAG model
29
+ tool_files_dict: Dictionary of tool files
30
+ force_finish: Whether to force finish when max tokens reached
31
+ enable_checker: Whether to enable reasoning trace checker
32
+ step_rag_num: Number of RAG tools to retrieve per step
33
+ seed: Random seed for reproducibility
34
+ """
35
  self.model_name = model_name
36
  self.rag_model_name = rag_model_name
37
  self.tool_files_dict = tool_files_dict or {}
 
38
  self.force_finish = force_finish
39
  self.enable_checker = enable_checker
40
  self.step_rag_num = step_rag_num
41
  self.seed = seed
42
 
43
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
  self.model = None
45
  self.tokenizer = None
46
  self.rag_model = None
 
47
 
48
+ logger.info(f"Initialized TxAgent with model: {model_name} on device: {self.device}")
49
 
50
  def init_model(self):
51
  """Initialize both the main model and RAG model."""
52
+ self.load_llm_model()
53
  self.load_rag_model()
54
  logger.info("Model initialization complete")
55
 
56
+ def load_llm_model(self):
57
  """Load the main LLM model."""
58
  try:
59
+ logger.info(f"Loading LLM model: {self.model_name}")
 
60
  self.tokenizer = AutoTokenizer.from_pretrained(
61
  self.model_name,
62
+ cache_dir=os.getenv("TRANSFORMERS_CACHE")
63
  )
 
64
  self.model = AutoModelForCausalLM.from_pretrained(
65
  self.model_name,
66
+ torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
67
  device_map="auto",
68
+ cache_dir=os.getenv("TRANSFORMERS_CACHE")
69
  )
70
+ logger.info(f"LLM model loaded on {self.device}")
 
 
71
  except Exception as e:
72
+ logger.error(f"Failed to load LLM model: {str(e)}")
73
+ raise RuntimeError(f"Failed to load LLM model: {str(e)}")
74
 
75
  def load_rag_model(self):
76
  """Load the RAG model."""
 
85
  logger.error(f"Failed to load RAG model: {str(e)}")
86
  raise RuntimeError(f"Failed to load RAG model: {str(e)}")
87
 
88
+ def process_document(self, file_path: str) -> Dict[str, Union[str, Dict]]:
89
+ """
90
+ Process a medical document and return analysis results.
91
+
92
+ Args:
93
+ file_path: Path to the document file (PDF, CSV, or Excel)
94
 
95
+ Returns:
96
+ Dictionary containing:
97
+ - status: "success" or "error"
98
+ - analysis: Detailed analysis results or error message
99
+ - model: Model used for analysis
100
+ """
101
+ try:
102
+ # 1. Extract text from document
103
+ text = self.extract_text_from_file(file_path)
104
+ if not text:
105
+ return {
106
+ "status": "error",
107
+ "message": "Failed to extract text - unsupported file type or empty document",
108
+ "model": self.model_name
109
+ }
110
+
111
+ # 2. Analyze with LLM
112
+ analysis = self.analyze_text(text)
113
 
114
  return {
115
  "status": "success",
116
+ "analysis": analysis,
117
  "model": self.model_name
118
  }
119
 
120
  except Exception as e:
121
+ logger.error(f"Document processing failed: {str(e)}", exc_info=True)
122
+ return {
123
+ "status": "error",
124
+ "message": f"Processing error: {str(e)}",
125
+ "model": self.model_name
126
+ }
127
 
128
+ def extract_text_from_file(self, file_path: str) -> Optional[str]:
129
+ """
130
+ Extract text from supported file types (PDF, CSV, Excel).
131
+
132
+ Args:
133
+ file_path: Path to the input file
134
+
135
+ Returns:
136
+ Extracted text as string, or None if extraction fails
137
+ """
138
+ try:
139
+ if file_path.endswith('.pdf'):
140
+ with pdfplumber.open(file_path) as pdf:
141
+ return "\n".join(
142
+ page.extract_text()
143
+ for page in pdf.pages
144
+ if page.extract_text()
145
+ )
146
+
147
+ elif file_path.endswith('.csv'):
148
+ df = pd.read_csv(file_path)
149
+ return df.to_string()
150
+
151
+ elif file_path.endswith(('.xlsx', '.xls')):
152
+ df = pd.read_excel(file_path)
153
+ return df.to_string()
154
+
155
+ logger.warning(f"Unsupported file type: {file_path}")
156
+ return None
157
+
158
+ except Exception as e:
159
+ logger.error(f"Text extraction failed: {str(e)}")
160
+ raise RuntimeError(f"Text extraction failed: {str(e)}")
161
 
162
+ def analyze_text(self, text: str, max_tokens: int = 1000) -> str:
163
+ """
164
+ Analyze extracted text using the LLM.
165
+
166
+ Args:
167
+ text: Text to analyze
168
+ max_tokens: Maximum tokens to generate
169
+
170
+ Returns:
171
+ Analysis results as string
172
+ """
173
+ try:
174
+ prompt = f"""Analyze this medical document and provide:
175
+ 1. Diagnostic patterns
176
+ 2. Medication issues
177
+ 3. Recommended follow-ups
178
+
179
+ Document:
180
+ {text[:8000]} # Truncate to avoid token limits
181
+ """
182
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
183
+ outputs = self.model.generate(
184
+ **inputs,
185
+ max_new_tokens=max_tokens,
186
+ pad_token_id=self.tokenizer.eos_token_id
187
+ )
188
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
189
+
190
+ except Exception as e:
191
+ logger.error(f"Text analysis failed: {str(e)}")
192
+ raise RuntimeError(f"Analysis failed: {str(e)}")
193
 
194
  def cleanup(self):
195
+ """Clean up resources and clear memory."""
196
  if hasattr(self, 'model'):
197
  del self.model
198
  if hasattr(self, 'rag_model'):