Ali2206 committed on
Commit
709aba9
·
verified ·
1 Parent(s): 88ae38d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -146
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import random
 
 
2
  import os
3
  import torch
4
  import logging
@@ -19,7 +21,6 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
19
  os.environ["MKL_THREADING_LAYER"] = "GNU"
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
 
22
- # Configuration
23
  CONFIG = {
24
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
25
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -33,94 +34,98 @@ CONFIG = {
33
  }
34
  }
35
 
36
- DESCRIPTION = '''
37
- <div>
38
- <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools</h1>
39
- </div>
40
- '''
41
-
42
- INTRO = """
43
- Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations.
44
- We introduce TxAgent, an AI agent that leverages multi-step reasoning and real-time biomedical knowledge
45
- retrieval across a toolbox of 211 expert-curated tools to navigate complex drug interactions,
46
- contraindications, and patient-specific treatment strategies, delivering evidence-grounded therapeutic decisions.
47
  """
48
 
49
- LICENSE = """
50
- We welcome your feedback and suggestions to enhance your experience with TxAgent, and if you're interested
51
- in collaboration, please email Marinka Zitnik and Shanghua Gao.
52
-
53
- ### Medical Advice Disclaimer
54
- DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
55
- The information, including but not limited to, text, graphics, images and other material contained on this
56
- website are for informational purposes only. No material on this site is intended to be a substitute for
57
- professional medical advice, diagnosis or treatment.
58
- """
59
-
60
- PLACEHOLDER = """
61
- <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
62
- <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
63
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
64
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️ (top-right) to remove previous context before submitting a new question.</p>
65
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
66
- </div>
67
- """
68
-
69
- def safe_load_embeddings(filepath: str):
70
- """Handle embedding loading with fallbacks"""
71
  try:
72
  return torch.load(filepath, weights_only=True)
73
  except Exception as e:
74
- logger.warning(f"Secure load failed, trying without weights_only: {str(e)}")
75
  try:
76
  return torch.load(filepath, weights_only=False)
77
  except Exception as e:
78
  logger.error(f"Failed to load embeddings: {str(e)}")
79
  return None
80
 
81
- def get_tools_from_universe(tooluniverse):
82
- """Flexible tool extraction from ToolUniverse"""
83
- if hasattr(tooluniverse, 'get_all_tools'):
84
- return tooluniverse.get_all_tools()
85
- elif hasattr(tooluniverse, 'tools'):
86
- return tooluniverse.tools
87
- elif hasattr(tooluniverse, 'list_tools'):
88
- return tooluniverse.list_tools()
89
- else:
90
- logger.error("Could not find any tool access method in ToolUniverse")
91
- # Try to load from files directly as fallback
92
- tools = []
93
- for tool_file in CONFIG["tool_files"].values():
94
- if os.path.exists(tool_file):
95
- with open(tool_file, 'r') as f:
96
- tools.extend(json.load(f))
97
- return tools if tools else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  def prepare_tool_files():
100
- """Ensure tool files exist and are populated"""
101
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
102
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
103
- logger.info("Generating tool list...")
104
  try:
105
  tu = ToolUniverse()
106
- tools = get_tools_from_universe(tu)
107
- if tools:
108
- with open(CONFIG["tool_files"]["new_tool"], "w") as f:
109
- json.dump(tools, f, indent=2)
110
- logger.info(f"Saved {len(tools)} tools")
111
  else:
112
- logger.error("No tools could be loaded")
 
 
 
 
 
113
  except Exception as e:
114
- logger.error(f"Tool file preparation failed: {str(e)}")
115
 
116
  def create_agent():
117
- """Create and initialize the TxAgent with robust error handling"""
118
  prepare_tool_files()
119
-
120
  try:
121
  agent = TxAgent(
122
- model_name=CONFIG["model_name"],
123
- rag_model_name=CONFIG["rag_model_name"],
124
  tool_files_dict=CONFIG["tool_files"],
125
  force_finish=True,
126
  enable_checker=True,
@@ -131,100 +136,71 @@ def create_agent():
131
  agent.init_model()
132
  return agent
133
  except Exception as e:
134
- logger.error(f"Agent creation failed: {str(e)}")
135
  raise
136
 
137
- def format_response(history, message):
138
- """Properly format responses for Gradio Chatbot"""
139
- if isinstance(message, (str, dict)):
140
- return history + [[None, str(message)]]
141
- elif hasattr(message, '__iter__'):
142
- full_response = ""
143
- for chunk in message:
144
- if isinstance(chunk, dict):
145
- full_response += chunk.get("content", "")
146
- else:
147
- full_response += str(chunk)
148
- return history + [[None, full_response]]
149
- return history + [[None, str(message)]]
150
 
151
- def create_demo(agent):
152
- """Create the Gradio interface with proper message handling"""
153
- with gr.Blocks() as demo:
154
- gr.Markdown(DESCRIPTION)
155
- gr.Markdown(INTRO)
156
-
157
- chatbot = gr.Chatbot(
158
- height=800,
159
- label='TxAgent',
160
- show_copy_button=True,
161
- bubble_full_width=False
 
162
  )
163
-
164
- msg = gr.Textbox(label="Input", placeholder="Type your question...")
165
- clear = gr.ClearButton([msg, chatbot])
166
-
167
- def respond(message, chat_history):
168
- try:
169
- # Convert Gradio history to agent format
170
- agent_history = []
171
- for user_msg, bot_msg in chat_history:
172
- if user_msg:
173
- agent_history.append({"role": "user", "content": user_msg})
174
- if bot_msg:
175
- agent_history.append({"role": "assistant", "content": bot_msg})
176
-
177
- # Get response from agent
178
- response = agent.run_gradio_chat(
179
- agent_history + [{"role": "user", "content": message}],
180
- temperature=0.3,
181
- max_new_tokens=1024,
182
- max_tokens=81920,
183
- multi_agent=False,
184
- conversation=[],
185
- max_round=30
186
- )
187
-
188
- # Format the response properly
189
- full_response = ""
190
- for chunk in response:
191
- if isinstance(chunk, dict):
192
- full_response += chunk.get("content", "")
193
- else:
194
- full_response += str(chunk)
195
-
196
- return chat_history + [(message, full_response)]
197
-
198
- except Exception as e:
199
- logger.error(f"Error in response handling: {str(e)}")
200
- return chat_history + [(message, f"Error: {str(e)}")]
201
-
202
- msg.submit(respond, [msg, chatbot], [chatbot])
203
- clear.click(lambda: [], None, [chatbot])
204
-
205
- # Add settings section
206
- with gr.Accordion("Settings", open=False):
207
- gr.Markdown("Adjust model parameters here")
208
-
209
- with gr.Row():
210
- temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
211
- max_new_tokens = gr.Slider(128, 4096, value=1024, step=1, label="Max New Tokens")
212
-
213
- with gr.Row():
214
- max_tokens = gr.Slider(128, 32000, value=81920, step=1, label="Max Tokens")
215
- max_round = gr.Slider(1, 50, value=30, step=1, label="Max Round")
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  return demo
218
 
219
  def main():
220
- """Main application entry point"""
221
  try:
 
222
  agent = create_agent()
223
  demo = create_demo(agent)
224
- demo.launch(server_name="0.0.0.0", server_port=7860)
225
  except Exception as e:
226
  logger.error(f"Application failed to start: {str(e)}")
227
  raise
228
 
229
  if __name__ == "__main__":
230
- main()
 
1
  import random
2
+ import datetime
3
+ import sys
4
  import os
5
  import torch
6
  import logging
 
21
  os.environ["MKL_THREADING_LAYER"] = "GNU"
22
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
 
 
24
  CONFIG = {
25
  "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
26
  "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
 
34
  }
35
  }
36
 
37
+ chat_css = """
38
+ .gr-button { font-size: 20px !important; }
39
+ .gr-button svg { width: 32px !important; height: 32px !important; }
 
 
 
 
 
 
 
 
40
  """
41
 
42
def safe_load_embeddings(filepath: str):
    """Load serialized embeddings from *filepath* with a safety fallback.

    Tries ``torch.load(..., weights_only=True)`` first (safe against
    arbitrary-code pickle payloads); if that fails, retries with a full
    unpickling load.

    Returns the loaded object (presumably an embedding tensor — the file
    only ever stores one here), or ``None`` if both attempts fail.

    Note: the original signature annotated the return as ``-> any``, which
    is the builtin *function* ``any``, not a type; the bogus annotation is
    removed.
    """
    try:
        return torch.load(filepath, weights_only=True)
    except Exception as e:
        logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
        try:
            # weights_only=False executes pickled code — acceptable only
            # because the embedding file ships with the app itself.
            return torch.load(filepath, weights_only=False)
        except Exception as e:
            logger.error(f"Failed to load embeddings: {str(e)}")
            return None
52
 
53
def patch_embedding_loading():
    """Monkey-patch ``ToolRAGModel.load_tool_desc_embedding``.

    Replaces the library loader with one that loads the precomputed
    embedding file named by ``CONFIG["embedding_filename"]`` and reconciles
    its row count with the current tool count (truncating or padding).

    Raises: re-raises any failure to import/patch ``txagent.toolrag``.
    """
    try:
        from txagent.toolrag import ToolRAGModel

        def patched_load(self, tooluniverse):
            # Returns True on success, False on any recoverable failure.
            try:
                if not os.path.exists(CONFIG["embedding_filename"]):
                    logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
                    return False

                self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
                # Fix: safe_load_embeddings returns None on failure; the
                # original fell through and crashed on len(None) below.
                if self.tool_desc_embedding is None:
                    logger.error("Embedding file could not be loaded")
                    return False

                if hasattr(tooluniverse, 'get_all_tools'):
                    tools = tooluniverse.get_all_tools()
                elif hasattr(tooluniverse, 'tools'):
                    tools = tooluniverse.tools
                else:
                    logger.error("No method found to access tools from ToolUniverse")
                    return False

                current_count = len(tools)
                embedding_count = len(self.tool_desc_embedding)

                if current_count != embedding_count:
                    logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")

                    if current_count < embedding_count:
                        self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
                        logger.info(f"Truncated embeddings to match {current_count} tools")
                    else:
                        # Pad by repeating the last row. unsqueeze(0) keeps
                        # every piece 2-D for torch.cat — the original cat'd
                        # 1-D rows onto the matrix, which raises a dimension
                        # mismatch. (Assumes a (num_tools, dim) matrix —
                        # TODO confirm against the saved embedding file.)
                        last_embedding = self.tool_desc_embedding[-1].unsqueeze(0)
                        padding = [last_embedding] * (current_count - embedding_count)
                        self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
                        logger.info(f"Padded embeddings to match {current_count} tools")

                return True

            except Exception as e:
                logger.error(f"Failed to load embeddings: {str(e)}")
                return False

        ToolRAGModel.load_tool_desc_embedding = patched_load
        logger.info("Successfully patched embedding loading")

    except Exception as e:
        logger.error(f"Failed to patch embedding loading: {str(e)}")
        raise
100
 
101
def prepare_tool_files():
    """Ensure the data directory and the generated tool-list JSON exist.

    When the "new_tool" file is missing, enumerate the tools from a fresh
    ToolUniverse instance and write them to disk. Failures are logged and
    swallowed — this is best-effort setup, never fatal.
    """
    os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)

    # Guard clause: nothing to do if the file is already there.
    if os.path.exists(CONFIG["tool_files"]["new_tool"]):
        return

    logger.info("Generating tool list using ToolUniverse...")
    try:
        universe = ToolUniverse()
        if hasattr(universe, 'get_all_tools'):
            tools = universe.get_all_tools()
        elif hasattr(universe, 'tools'):
            tools = universe.tools
        else:
            tools = []
            logger.error("Could not access tools from ToolUniverse")

        with open(CONFIG["tool_files"]["new_tool"], "w") as f:
            json.dump(tools, f, indent=2)
        logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
    except Exception as e:
        logger.error(f"Failed to prepare tool files: {str(e)}")
120
 
121
  def create_agent():
122
+ patch_embedding_loading()
123
  prepare_tool_files()
124
+
125
  try:
126
  agent = TxAgent(
127
+ CONFIG["model_name"],
128
+ CONFIG["rag_model_name"],
129
  tool_files_dict=CONFIG["tool_files"],
130
  force_finish=True,
131
  enable_checker=True,
 
136
  agent.init_model()
137
  return agent
138
  except Exception as e:
139
+ logger.error(f"Failed to create agent: {str(e)}")
140
  raise
141
 
142
def respond(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
    """Gradio submit handler: append the user turn, run the agent, return history.

    *history* is a list of ``{"role", "content"}`` dicts (Chatbot
    ``type="messages"``). Returns the history extended with the user message
    and either the agent's collected answer or an ``Error: ...`` message —
    it never propagates an exception back into Gradio.

    NOTE(review): relies on the module-global ``agent`` set in main();
    the value is not passed in as a parameter.
    """
    # Build the user turn once (the original rebuilt this dict three times).
    user_turn = {"role": "user", "content": message}
    updated_history = history + [user_turn]

    # Replaced raw print() debugging with logger.debug so production logs
    # stay controllable via logging config.
    logger.debug("User message: %s", message)
    logger.debug("Full history: %s", updated_history)

    try:
        # run_gradio_chat is called with (role, content) tuples here —
        # presumably its expected format; verify against the agent API.
        formatted_history = [(m["role"], m["content"]) for m in updated_history]

        response_generator = agent.run_gradio_chat(
            formatted_history,
            temperature,
            max_new_tokens,
            max_tokens,
            multi_agent,
            conversation,
            max_round,
        )

        # Consume INSIDE the try: a generator can raise while iterating,
        # which the original let escape past its except and into Gradio.
        collected = ""
        for chunk in response_generator:
            if isinstance(chunk, dict):
                collected += chunk.get("content", "")
            else:
                collected += str(chunk)
    except Exception as e:
        logger.exception("Agent call failed")
        return updated_history + [{"role": "assistant", "content": f"Error: {str(e)}"}]

    return updated_history + [{"role": "assistant", "content": collected}]
173
+
174
def create_demo(agent):
    """Assemble the Gradio Blocks UI and wire it to the respond handler.

    NOTE(review): the ``agent`` parameter is unused in this function —
    ``respond`` reads the module-global ``agent`` instead.
    """
    with gr.Blocks(css=chat_css) as demo:
        chat_window = gr.Chatbot(label="TxAgent", type="messages")

        with gr.Row():
            question_box = gr.Textbox(label="Your question")

        # Generation-parameter controls.
        with gr.Row():
            temperature_slider = gr.Slider(0, 1, value=0.3, label="Temperature")
            new_tokens_slider = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
            total_tokens_slider = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
            rounds_slider = gr.Slider(1, 30, value=30, label="Max Rounds")
            multi_agent_box = gr.Checkbox(label="Multi-Agent Mode")

        with gr.Row():
            ask_button = gr.Button("Ask TxAgent")

        ask_button.click(
            respond,
            inputs=[
                question_box,
                chat_window,
                temperature_slider,
                new_tokens_slider,
                total_tokens_slider,
                multi_agent_box,
                gr.State([]),  # fresh per-session conversation state
                rounds_slider,
            ],
            outputs=[chat_window],
        )
    return demo
194
 
195
def main():
    """Entry point: build the agent, construct the UI, and launch the server."""
    # `global` is a scope-wide declaration, so hoisting it above the try
    # changes nothing at runtime — it just reads more conventionally.
    global agent
    try:
        agent = create_agent()
        demo = create_demo(agent)
        demo.launch()
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise
204
 
205
  if __name__ == "__main__":
206
+ main()