Ali2206 commited on
Commit
f7f7fed
·
verified ·
1 Parent(s): 698378b

Update src/txagent/txagent.py

Browse files
Files changed (1) hide show
  1. src/txagent/txagent.py +13 -55
src/txagent/txagent.py CHANGED
@@ -7,7 +7,6 @@ from typing import Dict, Optional, Union
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from sentence_transformers import SentenceTransformer
9
 
10
- # Configure logging
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
12
  logger = logging.getLogger("TxAgent")
13
 
@@ -20,18 +19,7 @@ class TxAgent:
20
  enable_checker: bool = True,
21
  step_rag_num: int = 4,
22
  seed: Optional[int] = None):
23
- """
24
- Initialize the TxAgent with specified configuration.
25
-
26
- Args:
27
- model_name: Name/path of the main LLM model
28
- rag_model_name: Name/path of the RAG model
29
- tool_files_dict: Dictionary of tool files
30
- force_finish: Whether to force finish when max tokens reached
31
- enable_checker: Whether to enable reasoning trace checker
32
- step_rag_num: Number of RAG tools to retrieve per step
33
- seed: Random seed for reproducibility
34
- """
35
  self.model_name = model_name
36
  self.rag_model_name = rag_model_name
37
  self.tool_files_dict = tool_files_dict or {}
@@ -48,24 +36,24 @@ class TxAgent:
48
  logger.info(f"Initialized TxAgent with model: {model_name} on device: {self.device}")
49
 
50
  def init_model(self):
51
- """Initialize both the main model and RAG model."""
52
  self.load_llm_model()
53
  self.load_rag_model()
54
  logger.info("Model initialization complete")
55
 
56
  def load_llm_model(self):
57
- """Load the main LLM model."""
58
  try:
59
  logger.info(f"Loading LLM model: {self.model_name}")
60
  self.tokenizer = AutoTokenizer.from_pretrained(
61
  self.model_name,
62
- cache_dir=os.getenv("TRANSFORMERS_CACHE")
63
  )
64
  self.model = AutoModelForCausalLM.from_pretrained(
65
  self.model_name,
66
  torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
67
  device_map="auto",
68
- cache_dir=os.getenv("TRANSFORMERS_CACHE")
69
  )
70
  logger.info(f"LLM model loaded on {self.device}")
71
  except Exception as e:
@@ -86,29 +74,16 @@ class TxAgent:
86
  raise RuntimeError(f"Failed to load RAG model: {str(e)}")
87
 
88
  def process_document(self, file_path: str) -> Dict[str, Union[str, Dict]]:
89
- """
90
- Process a medical document and return analysis results.
91
-
92
- Args:
93
- file_path: Path to the document file (PDF, CSV, or Excel)
94
-
95
- Returns:
96
- Dictionary containing:
97
- - status: "success" or "error"
98
- - analysis: Detailed analysis results or error message
99
- - model: Model used for analysis
100
- """
101
  try:
102
- # 1. Extract text from document
103
  text = self.extract_text_from_file(file_path)
104
  if not text:
105
  return {
106
  "status": "error",
107
- "message": "Failed to extract text - unsupported file type or empty document",
108
  "model": self.model_name
109
  }
110
 
111
- # 2. Analyze with LLM
112
  analysis = self.analyze_text(text)
113
 
114
  return {
@@ -118,23 +93,15 @@ class TxAgent:
118
  }
119
 
120
  except Exception as e:
121
- logger.error(f"Document processing failed: {str(e)}", exc_info=True)
122
  return {
123
  "status": "error",
124
- "message": f"Processing error: {str(e)}",
125
  "model": self.model_name
126
  }
127
 
128
  def extract_text_from_file(self, file_path: str) -> Optional[str]:
129
- """
130
- Extract text from supported file types (PDF, CSV, Excel).
131
-
132
- Args:
133
- file_path: Path to the input file
134
-
135
- Returns:
136
- Extracted text as string, or None if extraction fails
137
- """
138
  try:
139
  if file_path.endswith('.pdf'):
140
  with pdfplumber.open(file_path) as pdf:
@@ -160,18 +127,9 @@ class TxAgent:
160
  raise RuntimeError(f"Text extraction failed: {str(e)}")
161
 
162
  def analyze_text(self, text: str, max_tokens: int = 1000) -> str:
163
- """
164
- Analyze extracted text using the LLM.
165
-
166
- Args:
167
- text: Text to analyze
168
- max_tokens: Maximum tokens to generate
169
-
170
- Returns:
171
- Analysis results as string
172
- """
173
  try:
174
- prompt = f"""Analyze this medical document and provide:
175
  1. Diagnostic patterns
176
  2. Medication issues
177
  3. Recommended follow-ups
@@ -192,7 +150,7 @@ Document:
192
  raise RuntimeError(f"Analysis failed: {str(e)}")
193
 
194
  def cleanup(self):
195
- """Clean up resources and clear memory."""
196
  if hasattr(self, 'model'):
197
  del self.model
198
  if hasattr(self, 'rag_model'):
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from sentence_transformers import SentenceTransformer
9
 
 
10
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
11
  logger = logging.getLogger("TxAgent")
12
 
 
19
  enable_checker: bool = True,
20
  step_rag_num: int = 4,
21
  seed: Optional[int] = None):
22
+ """Initialize TxAgent without vLLM dependencies."""
 
 
 
 
 
 
 
 
 
 
 
23
  self.model_name = model_name
24
  self.rag_model_name = rag_model_name
25
  self.tool_files_dict = tool_files_dict or {}
 
36
  logger.info(f"Initialized TxAgent with model: {model_name} on device: {self.device}")
37
 
38
  def init_model(self):
39
+ """Initialize models using transformers only."""
40
  self.load_llm_model()
41
  self.load_rag_model()
42
  logger.info("Model initialization complete")
43
 
44
  def load_llm_model(self):
45
+ """Load the main LLM model using transformers."""
46
  try:
47
  logger.info(f"Loading LLM model: {self.model_name}")
48
  self.tokenizer = AutoTokenizer.from_pretrained(
49
  self.model_name,
50
+ cache_dir=os.getenv("HF_HOME")
51
  )
52
  self.model = AutoModelForCausalLM.from_pretrained(
53
  self.model_name,
54
  torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
55
  device_map="auto",
56
+ cache_dir=os.getenv("HF_HOME")
57
  )
58
  logger.info(f"LLM model loaded on {self.device}")
59
  except Exception as e:
 
74
  raise RuntimeError(f"Failed to load RAG model: {str(e)}")
75
 
76
  def process_document(self, file_path: str) -> Dict[str, Union[str, Dict]]:
77
+ """Process a document and return real analysis results."""
 
 
 
 
 
 
 
 
 
 
 
78
  try:
 
79
  text = self.extract_text_from_file(file_path)
80
  if not text:
81
  return {
82
  "status": "error",
83
+ "message": "Failed to extract text",
84
  "model": self.model_name
85
  }
86
 
 
87
  analysis = self.analyze_text(text)
88
 
89
  return {
 
93
  }
94
 
95
  except Exception as e:
96
+ logger.error(f"Document processing failed: {str(e)}")
97
  return {
98
  "status": "error",
99
+ "message": str(e),
100
  "model": self.model_name
101
  }
102
 
103
  def extract_text_from_file(self, file_path: str) -> Optional[str]:
104
+ """Extract text from PDF, CSV, or Excel files."""
 
 
 
 
 
 
 
 
105
  try:
106
  if file_path.endswith('.pdf'):
107
  with pdfplumber.open(file_path) as pdf:
 
127
  raise RuntimeError(f"Text extraction failed: {str(e)}")
128
 
129
  def analyze_text(self, text: str, max_tokens: int = 1000) -> str:
130
+ """Analyze extracted text using the LLM."""
 
 
 
 
 
 
 
 
 
131
  try:
132
+ prompt = f"""Analyze this medical document:
133
  1. Diagnostic patterns
134
  2. Medication issues
135
  3. Recommended follow-ups
 
150
  raise RuntimeError(f"Analysis failed: {str(e)}")
151
 
152
  def cleanup(self):
153
+ """Clean up resources."""
154
  if hasattr(self, 'model'):
155
  del self.model
156
  if hasattr(self, 'rag_model'):