Ali2206 committed on
Commit
bae0943
·
verified ·
1 Parent(s): 31f2bf3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -107
app.py CHANGED
@@ -4,19 +4,23 @@ import logging
4
  import torch
5
  import gradio as gr
6
  from tooluniverse import ToolUniverse
7
- from txagent import TxAgent # Updated import statement
8
  import warnings
9
  from typing import List, Dict, Any
10
 
11
  # Suppress specific warnings
12
  warnings.filterwarnings("ignore", category=UserWarning)
13
 
14
- # Configuration with hardcoded embedding file
15
  CONFIG = {
16
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
17
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
18
  "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
19
  "tool_files": {
 
 
 
 
20
  "new_tool": "./data/new_tool.json"
21
  }
22
  }
@@ -34,7 +38,7 @@ def prepare_tool_files():
34
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
35
  logger.info("Generating tool list using ToolUniverse...")
36
  tu = ToolUniverse()
37
- tools = tu.get_all_tools()
38
  with open(CONFIG["tool_files"]["new_tool"], "w") as f:
39
  json.dump(tools, f, indent=2)
40
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
@@ -46,140 +50,134 @@ def safe_load_embeddings(filepath: str) -> Any:
46
  return torch.load(filepath, weights_only=True)
47
  except Exception as e:
48
  logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
49
- # If that fails, try with weights_only=False (less secure)
50
- return torch.load(filepath, weights_only=False)
51
-
52
- def patch_embedding_loading():
53
- """Monkey-patch the embedding loading functionality"""
54
- try:
55
- from txagent.toolrag import ToolRAGModel
56
-
57
- original_load = ToolRAGModel.load_tool_desc_embedding
58
-
59
- def patched_load(self, tooluniverse: ToolUniverse) -> bool:
60
- try:
61
- if not os.path.exists(CONFIG["embedding_filename"]):
62
- logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
63
- return False
64
-
65
- # Load embeddings safely
66
- self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
67
-
68
- # Handle tool count mismatch
69
- tools = tooluniverse.get_all_tools()
70
- current_count = len(tools)
71
- embedding_count = len(self.tool_desc_embedding)
72
-
73
- if current_count != embedding_count:
74
- logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
75
-
76
- if current_count < embedding_count:
77
- self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
78
- logger.info(f"Truncated embeddings to match {current_count} tools")
79
- else:
80
- last_embedding = self.tool_desc_embedding[-1]
81
- padding = [last_embedding] * (current_count - embedding_count)
82
- self.tool_desc_embedding = torch.cat(
83
- [self.tool_desc_embedding] + padding
84
- )
85
- logger.info(f"Padded embeddings to match {current_count} tools")
86
-
87
- return True
88
-
89
- except Exception as e:
90
- logger.error(f"Failed to load embeddings: {str(e)}")
91
- return False
92
-
93
- # Apply the patch
94
- ToolRAGModel.load_tool_desc_embedding = patched_load
95
- logger.info("Successfully patched embedding loading")
96
-
97
- except Exception as e:
98
- logger.error(f"Failed to patch embedding loading: {str(e)}")
99
- raise
100
 
101
- class TxAgentApp:
102
  def __init__(self):
103
- self.agent = None
 
 
 
104
  self.is_initialized = False
 
105
 
106
  def initialize(self) -> str:
107
- """Initialize the TxAgent with all required components"""
108
  if self.is_initialized:
109
  return "✅ Already initialized"
110
 
111
  try:
112
- # Apply our patch before initialization
113
- patch_embedding_loading()
 
 
 
 
 
 
 
 
114
 
115
- logger.info("Initializing TxAgent...")
116
- self.agent = TxAgent(
117
- model_name=CONFIG["model_name"],
118
- rag_model_name=CONFIG["rag_model_name"],
119
- tool_files_dict=CONFIG["tool_files"],
120
- force_finish=True,
121
- enable_checker=True,
122
- step_rag_num=10,
123
- seed=100,
124
- additional_default_tools=["DirectResponse", "RequireClarification"]
125
  )
126
 
127
- logger.info("Loading models...")
128
- self.agent.init_model()
 
 
 
129
 
130
  self.is_initialized = True
131
- return "✅ TxAgent initialized successfully"
132
 
133
  except Exception as e:
134
  logger.error(f"Initialization failed: {str(e)}")
135
  return f"❌ Initialization failed: {str(e)}"
136
 
137
  def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
138
- """
139
- Handle chat interactions with the TxAgent
140
-
141
- Args:
142
- message: User input message
143
- history: Chat history in format [[user_msg, bot_msg], ...]
144
-
145
- Returns:
146
- Updated chat history
147
- """
148
  if not self.is_initialized:
149
  return history + [["", "⚠️ Please initialize the model first"]]
150
 
151
  try:
152
- # Convert history to the format TxAgent expects
153
- tx_history = []
154
- for user_msg, bot_msg in history:
155
- tx_history.append({"role": "user", "content": user_msg})
156
- if bot_msg: # Only add bot response if it exists
157
- tx_history.append({"role": "assistant", "content": bot_msg})
 
 
 
 
 
 
158
 
159
  # Generate response
160
- response = ""
161
- for chunk in self.agent.run_gradio_chat(
162
- message=message,
163
- history=tx_history,
164
- temperature=0.3,
 
 
 
165
  max_new_tokens=1024,
166
- max_token=8192,
167
- call_agent=False,
168
- conversation=None,
169
- max_round=30
170
- ):
171
- response = chunk # Get the final response
 
 
172
 
173
- # Format response for Gradio Chatbot
174
  return history + [[message, response]]
175
-
176
  except Exception as e:
177
  logger.error(f"Chat error: {str(e)}")
178
  return history + [["", f"Error: {str(e)}"]]
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  def create_interface() -> gr.Blocks:
181
  """Create the Gradio interface"""
182
- app = TxAgentApp()
183
 
184
  with gr.Blocks(
185
  title="TxAgent",
@@ -189,7 +187,7 @@ def create_interface() -> gr.Blocks:
189
  ) as demo:
190
  gr.Markdown("""
191
  # 🧠 TxAgent: Therapeutic Reasoning AI
192
- ### (Using pre-loaded embeddings)
193
  """)
194
 
195
  with gr.Row():
@@ -212,9 +210,8 @@ def create_interface() -> gr.Blocks:
212
  inputs=msg
213
  )
214
 
215
- def wrapper_initialize() -> tuple:
216
- """Wrapper for initialization with UI updates"""
217
- status = app.initialize()
218
  return status, gr.update(interactive=False)
219
 
220
  init_btn.click(
@@ -223,7 +220,7 @@ def create_interface() -> gr.Blocks:
223
  )
224
 
225
  msg.submit(
226
- fn=app.chat,
227
  inputs=[msg, chatbot],
228
  outputs=chatbot
229
  ).then(
 
4
  import torch
5
  import gradio as gr
6
  from tooluniverse import ToolUniverse
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
  import warnings
9
  from typing import List, Dict, Any
10
 
11
  # Suppress specific warnings
12
  warnings.filterwarnings("ignore", category=UserWarning)
13
 
14
+ # Configuration
15
  CONFIG = {
16
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
17
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
18
  "embedding_filename": "ToolRAG-T1-GTE-Qwen2-1.5Btool_embedding_47dc56b3e3ddeb31af4f19defdd538d984de1500368852a0fab80bc2e826c944.pt",
19
  "tool_files": {
20
+ "opentarget": "opentarget_tools.json",
21
+ "fda_drug_label": "fda_drug_labeling_tools.json",
22
+ "special_tools": "special_tools.json",
23
+ "monarch": "monarch_tools.json",
24
  "new_tool": "./data/new_tool.json"
25
  }
26
  }
 
38
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
39
  logger.info("Generating tool list using ToolUniverse...")
40
  tu = ToolUniverse()
41
+ tools = tu.get_all_tools() if hasattr(tu, 'get_all_tools') else []
42
  with open(CONFIG["tool_files"]["new_tool"], "w") as f:
43
  json.dump(tools, f, indent=2)
44
  logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
 
50
  return torch.load(filepath, weights_only=True)
51
  except Exception as e:
52
  logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
53
+ try:
54
+ # Try with the safe_globals context manager
55
+ with torch.serialization.safe_globals([torch.serialization._reconstruct]):
56
+ return torch.load(filepath, weights_only=False)
57
+ except Exception as e:
58
+ logger.error(f"Failed to load embeddings even with safe_globals: {str(e)}")
59
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
class TxAgentWrapper:
    """Load the configured causal LM from the Hugging Face Hub and chat with it.

    The wrapper owns the tokenizer, the model, the ToolUniverse instance and
    the (optional) pre-computed tool embeddings.  Nothing heavy is loaded in
    ``__init__``; call :meth:`initialize` once before :meth:`chat`.
    """

    # Shortest message (in characters) that ``chat`` will send to the model.
    MIN_MESSAGE_LENGTH = 10

    def __init__(self):
        self.model = None
        self.tokenizer = None
        # Holds whatever ``safe_load_embeddings`` returns for the embedding file.
        self.rag_model = None
        self.tooluniverse = None
        self.is_initialized = False
        # Tool names that are excluded from the regular tool listing.
        self.special_tools = ['Finish', 'Tool_RAG', 'DirectResponse', 'RequireClarification']

    def initialize(self) -> str:
        """Load tools, tokenizer, model and embeddings.

        Returns:
            A human-readable status string for the UI: a success message, an
            "already initialized" notice, or a failure message.  Failures are
            reported through the return value rather than raised.
        """
        if self.is_initialized:
            return "✅ Already initialized"

        try:
            logger.info("Loading models from Hugging Face Hub...")

            # Initialize ToolUniverse first
            self.tooluniverse = ToolUniverse(tool_files=CONFIG["tool_files"])
            if not hasattr(self.tooluniverse, 'load_tools'):
                logger.error("ToolUniverse doesn't have load_tools method")
                return "❌ Failed to load tools"
            self.tooluniverse.load_tools()
            logger.info(f"Loaded {len(self.tooluniverse.tools)} tools")

            # Load main model
            self.tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
            self.model = AutoModelForCausalLM.from_pretrained(
                CONFIG["model_name"],
                device_map="auto",
                torch_dtype=torch.float16,
            )

            # Embeddings are optional: a missing file is skipped, but a file
            # that exists and cannot be loaded aborts initialization.
            if os.path.exists(CONFIG["embedding_filename"]):
                self.rag_model = safe_load_embeddings(CONFIG["embedding_filename"])
                if self.rag_model is None:
                    return "❌ Failed to load embeddings"

            self.is_initialized = True
            return "✅ Model initialized successfully"

        except Exception as e:
            # UI boundary: log with traceback and report instead of crashing.
            logger.exception(f"Initialization failed: {str(e)}")
            return f"❌ Initialization failed: {str(e)}"

    def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
        """Generate a reply to ``message`` and return the extended history.

        Args:
            message: The user's new message.
            history: Previous turns as ``[user_msg, bot_msg]`` pairs.

        Returns:
            A new list: ``history`` plus one ``[user_msg, bot_msg]`` pair.  For
            guard and error replies the user slot of that pair is ``""``.
        """
        if not self.is_initialized:
            return history + [["", "⚠️ Please initialize the model first"]]

        try:
            # Strictly "fewer than": a message of exactly MIN_MESSAGE_LENGTH
            # characters satisfies the "at least" wording shown to the user.
            if len(message) < self.MIN_MESSAGE_LENGTH:
                return history + [[
                    "",
                    "Please provide a more detailed question "
                    f"(at least {self.MIN_MESSAGE_LENGTH} characters)",
                ]]

            # NOTE: the tool listing from _prepare_tools_prompt is not added to
            # the conversation; everything from "[TOOL_CALLS]" on is stripped
            # from the reply below, so no tool call is ever acted upon here.
            conversation = [
                {"role": "system", "content": "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning."},
                *self._format_history(history),
                {"role": "user", "content": message},
            ]

            # Generate response
            inputs = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(self.model.device)

            outputs = self.model.generate(
                inputs,
                max_new_tokens=1024,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Decode only the newly generated tokens (skip the prompt part).
            prompt_length = len(inputs[0])
            response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            response = response.split("[TOOL_CALLS]")[0].strip()

            return history + [[message, response]]

        except Exception as e:
            # UI boundary: log with traceback and surface the error in the chat.
            logger.exception(f"Chat error: {str(e)}")
            return history + [["", f"Error: {str(e)}"]]

    def _prepare_tools_prompt(self, message: str) -> str:
        """Build the textual tool listing.

        Args:
            message: Unused; kept so the signature stays unchanged.

        Returns:
            ``""`` when no tool list is available, otherwise a listing of every
            tool whose name is not in ``self.special_tools`` followed by a
            fixed "Special tools" section.
        """
        tools = getattr(self.tooluniverse, 'tools', None)
        if tools is None:
            return ""

        parts = ["\n\nYou have access to the following tools:\n"]
        parts.extend(
            f"- {tool['name']}: {tool['description']}\n"
            for tool in tools
            if tool['name'] not in self.special_tools
        )

        # Add special tools
        parts.append("\nSpecial tools:\n")
        parts.append("- Finish: Use when you have the final answer\n")
        parts.append("- Tool_RAG: Search for additional tools when needed\n")

        return "".join(parts)

    def _format_history(self, history: List[List[str]]) -> List[Dict[str, str]]:
        """Convert ``[user_msg, bot_msg]`` pairs into role/content dicts.

        Every pair yields a ``user`` entry; an ``assistant`` entry is added
        only when the bot message is truthy.
        """
        formatted = []
        for user_msg, bot_msg in history:
            formatted.append({"role": "user", "content": user_msg})
            if bot_msg:
                formatted.append({"role": "assistant", "content": bot_msg})
        return formatted
177
+
178
  def create_interface() -> gr.Blocks:
179
  """Create the Gradio interface"""
180
+ agent = TxAgentWrapper()
181
 
182
  with gr.Blocks(
183
  title="TxAgent",
 
187
  ) as demo:
188
  gr.Markdown("""
189
  # 🧠 TxAgent: Therapeutic Reasoning AI
190
+ ### (Loading from Hugging Face Hub)
191
  """)
192
 
193
  with gr.Row():
 
210
  inputs=msg
211
  )
212
 
213
+ def wrapper_initialize():
214
+ status = agent.initialize()
 
215
  return status, gr.update(interactive=False)
216
 
217
  init_btn.click(
 
220
  )
221
 
222
  msg.submit(
223
+ fn=agent.chat,
224
  inputs=[msg, chatbot],
225
  outputs=chatbot
226
  ).then(