Ali2206 committed on
Commit
88ae38d
·
verified ·
1 Parent(s): 31db4cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -105
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import random
2
- import datetime
3
- import sys
4
  import os
5
  import torch
6
  import logging
@@ -17,7 +15,6 @@ logging.basicConfig(
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
- # Determine the directory where the current file is located
21
  current_dir = os.path.dirname(os.path.abspath(__file__))
22
  os.environ["MKL_THREADING_LAYER"] = "GNU"
23
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -36,98 +33,94 @@ CONFIG = {
36
  }
37
  }
38
 
39
- chat_css = """
40
- .gr-button { font-size: 20px !important; }
41
- .gr-button svg { width: 32px !important; height: 32px !important; }
 
 
 
 
 
 
 
 
42
  """
43
 
44
- def safe_load_embeddings(filepath: str) -> any:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  try:
46
  return torch.load(filepath, weights_only=True)
47
  except Exception as e:
48
- logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
49
  try:
50
  return torch.load(filepath, weights_only=False)
51
  except Exception as e:
52
  logger.error(f"Failed to load embeddings: {str(e)}")
53
  return None
54
 
55
- def patch_embedding_loading():
56
- try:
57
- from txagent.toolrag import ToolRAGModel
58
-
59
- def patched_load(self, tooluniverse):
60
- try:
61
- if not os.path.exists(CONFIG["embedding_filename"]):
62
- logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
63
- return False
64
-
65
- self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
66
-
67
- if hasattr(tooluniverse, 'get_all_tools'):
68
- tools = tooluniverse.get_all_tools()
69
- elif hasattr(tooluniverse, 'tools'):
70
- tools = tooluniverse.tools
71
- else:
72
- logger.error("No method found to access tools from ToolUniverse")
73
- return False
74
-
75
- current_count = len(tools)
76
- embedding_count = len(self.tool_desc_embedding)
77
-
78
- if current_count != embedding_count:
79
- logger.warning(f"Tool count mismatch (tools: {current_count}, embeddings: {embedding_count})")
80
-
81
- if current_count < embedding_count:
82
- self.tool_desc_embedding = self.tool_desc_embedding[:current_count]
83
- logger.info(f"Truncated embeddings to match {current_count} tools")
84
- else:
85
- last_embedding = self.tool_desc_embedding[-1]
86
- padding = [last_embedding] * (current_count - embedding_count)
87
- self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
88
- logger.info(f"Padded embeddings to match {current_count} tools")
89
-
90
- return True
91
-
92
- except Exception as e:
93
- logger.error(f"Failed to load embeddings: {str(e)}")
94
- return False
95
-
96
- ToolRAGModel.load_tool_desc_embedding = patched_load
97
- logger.info("Successfully patched embedding loading")
98
-
99
- except Exception as e:
100
- logger.error(f"Failed to patch embedding loading: {str(e)}")
101
- raise
102
 
103
  def prepare_tool_files():
 
104
  os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
105
  if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
106
- logger.info("Generating tool list using ToolUniverse...")
107
  try:
108
  tu = ToolUniverse()
109
- if hasattr(tu, 'get_all_tools'):
110
- tools = tu.get_all_tools()
111
- elif hasattr(tu, 'tools'):
112
- tools = tu.tools
 
113
  else:
114
- tools = []
115
- logger.error("Could not access tools from ToolUniverse")
116
-
117
- with open(CONFIG["tool_files"]["new_tool"], "w") as f:
118
- json.dump(tools, f, indent=2)
119
- logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
120
  except Exception as e:
121
- logger.error(f"Failed to prepare tool files: {str(e)}")
122
 
123
  def create_agent():
124
- patch_embedding_loading()
125
  prepare_tool_files()
126
-
127
  try:
128
  agent = TxAgent(
129
- CONFIG["model_name"],
130
- CONFIG["rag_model_name"],
131
  tool_files_dict=CONFIG["tool_files"],
132
  force_finish=True,
133
  enable_checker=True,
@@ -138,51 +131,97 @@ def create_agent():
138
  agent.init_model()
139
  return agent
140
  except Exception as e:
141
- logger.error(f"Failed to create agent: {str(e)}")
142
  raise
143
 
144
- def respond(message, history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
145
- updated_history = history + [{"role": "user", "content": message}]
146
- response_generator = agent.run_gradio_chat(
147
- updated_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round
148
- )
149
-
150
- collected = ""
151
- for chunk in response_generator:
152
- if isinstance(chunk, dict):
153
- collected += chunk.get("content", "")
154
- else:
155
- collected += str(chunk)
156
-
157
- return history + [{"role": "user", "content": message}, {"role": "assistant", "content": collected}]
158
 
159
  def create_demo(agent):
160
- with gr.Blocks(css=chat_css) as demo:
161
- chatbot = gr.Chatbot(label="TxAgent", type="messages")
162
- with gr.Row():
163
- msg = gr.Textbox(label="Your question")
164
- with gr.Row():
165
- temp = gr.Slider(0, 1, value=0.3, label="Temperature")
166
- max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
167
- max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
168
- max_rounds = gr.Slider(1, 30, value=30, label="Max Rounds")
169
- multi_agent = gr.Checkbox(label="Multi-Agent Mode")
170
- with gr.Row():
171
- submit = gr.Button("Ask TxAgent")
172
-
173
- submit.click(
174
- respond,
175
- inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
176
- outputs=[chatbot]
177
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  return demo
179
 
180
  def main():
 
181
  try:
182
- global agent
183
  agent = create_agent()
184
  demo = create_demo(agent)
185
- demo.launch()
186
  except Exception as e:
187
  logger.error(f"Application failed to start: {str(e)}")
188
  raise
 
1
  import random
 
 
2
  import os
3
  import torch
4
  import logging
 
15
  )
16
  logger = logging.getLogger(__name__)
17
 
 
18
  current_dir = os.path.dirname(os.path.abspath(__file__))
19
  os.environ["MKL_THREADING_LAYER"] = "GNU"
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
33
  }
34
  }
35
 
36
+ DESCRIPTION = '''
37
+ <div>
38
+ <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools</h1>
39
+ </div>
40
+ '''
41
+
42
+ INTRO = """
43
+ Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations.
44
+ We introduce TxAgent, an AI agent that leverages multi-step reasoning and real-time biomedical knowledge
45
+ retrieval across a toolbox of 211 expert-curated tools to navigate complex drug interactions,
46
+ contraindications, and patient-specific treatment strategies, delivering evidence-grounded therapeutic decisions.
47
  """
48
 
49
+ LICENSE = """
50
+ We welcome your feedback and suggestions to enhance your experience with TxAgent, and if you're interested
51
+ in collaboration, please email Marinka Zitnik and Shanghua Gao.
52
+
53
+ ### Medical Advice Disclaimer
54
+ DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
55
+ The information, including but not limited to, text, graphics, images and other material contained on this
56
+ website are for informational purposes only. No material on this site is intended to be a substitute for
57
+ professional medical advice, diagnosis or treatment.
58
+ """
59
+
60
+ PLACEHOLDER = """
61
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
62
+ <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
63
+ <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
64
+ <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️ (top-right) to remove previous context before submitting a new question.</p>
65
+ <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
66
+ </div>
67
+ """
68
+
69
def safe_load_embeddings(filepath: str):
    """Load tensor embeddings from *filepath* with a safe-first strategy.

    Attempts the restricted ``weights_only=True`` mode first (no arbitrary
    unpickling); if that raises, logs a warning and retries with full
    unpickling. Returns ``None`` when both attempts fail.
    """
    try:
        return torch.load(filepath, weights_only=True)
    except Exception as secure_err:
        logger.warning(f"Secure load failed, trying without weights_only: {str(secure_err)}")
    try:
        return torch.load(filepath, weights_only=False)
    except Exception as load_err:
        logger.error(f"Failed to load embeddings: {str(load_err)}")
        return None
80
 
81
def get_tools_from_universe(tooluniverse):
    """Return the tool list exposed by *tooluniverse*.

    Probes the known accessor spellings in order — ``get_all_tools()``,
    the ``tools`` attribute, then ``list_tools()``. When none of them
    exists, logs an error and falls back to reading the configured tool
    files straight from disk; returns ``None`` if nothing was found.
    """
    if hasattr(tooluniverse, 'get_all_tools'):
        return tooluniverse.get_all_tools()
    if hasattr(tooluniverse, 'tools'):
        return tooluniverse.tools
    if hasattr(tooluniverse, 'list_tools'):
        return tooluniverse.list_tools()

    logger.error("Could not find any tool access method in ToolUniverse")
    # Last resort: aggregate whatever tool definitions exist on disk.
    collected = []
    for tool_path in CONFIG["tool_files"].values():
        if os.path.exists(tool_path):
            with open(tool_path, 'r') as fh:
                collected.extend(json.load(fh))
    return collected if collected else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
def prepare_tool_files():
    """Ensure the data directory exists and the tool list file is populated.

    On first run (target file missing) this instantiates a ToolUniverse,
    extracts its tools via get_tools_from_universe, and writes them as
    pretty-printed JSON. All failures are logged, never raised.
    """
    os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
    target = CONFIG["tool_files"]["new_tool"]
    if os.path.exists(target):
        # Tool list already materialised; nothing to do.
        return
    logger.info("Generating tool list...")
    try:
        universe = ToolUniverse()
        tools = get_tools_from_universe(universe)
        if not tools:
            logger.error("No tools could be loaded")
            return
        with open(target, "w") as out:
            json.dump(tools, out, indent=2)
        logger.info(f"Saved {len(tools)} tools")
    except Exception as e:
        logger.error(f"Tool file preparation failed: {str(e)}")
115
 
116
  def create_agent():
117
+ """Create and initialize the TxAgent with robust error handling"""
118
  prepare_tool_files()
119
+
120
  try:
121
  agent = TxAgent(
122
+ model_name=CONFIG["model_name"],
123
+ rag_model_name=CONFIG["rag_model_name"],
124
  tool_files_dict=CONFIG["tool_files"],
125
  force_finish=True,
126
  enable_checker=True,
 
131
  agent.init_model()
132
  return agent
133
  except Exception as e:
134
+ logger.error(f"Agent creation failed: {str(e)}")
135
  raise
136
 
137
def format_response(history, message):
    """Append *message* to a Gradio pair-style chat *history* as a bot turn.

    Parameters
    ----------
    history : list
        Chatbot history in ``[[user, bot], ...]`` pair form.
    message : str | dict | iterable | object
        A finished reply string, a single agent chunk dict, an iterable /
        generator of chunks, or any other object (stringified as-is).

    Returns
    -------
    list
        A new history list with one ``[None, text]`` pair appended; the
        input list is not mutated.
    """
    # BUG FIX: a dict message used to be stringified wholesale
    # (producing "{'content': ...}"), inconsistent with the streaming
    # branch below which extracts the "content" field.
    if isinstance(message, dict):
        return history + [[None, message.get("content", "")]]
    if isinstance(message, str):
        return history + [[None, message]]
    if hasattr(message, '__iter__'):
        # Streaming case: concatenate chunk texts into one reply.
        full_response = ""
        for chunk in message:
            if isinstance(chunk, dict):
                full_response += chunk.get("content", "")
            else:
                full_response += str(chunk)
        return history + [[None, full_response]]
    return history + [[None, str(message)]]
 
150
 
151
def create_demo(agent):
    """Create the Gradio interface with proper message handling.

    Builds a Blocks app: header markdown, a chat window, an input textbox
    with a clear button, and a collapsible Settings panel whose sliders
    are wired into the agent call (previously they were rendered but
    never read). Returns the demo object for the caller to launch.
    """
    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION)
        gr.Markdown(INTRO)

        chatbot = gr.Chatbot(
            height=800,
            label='TxAgent',
            show_copy_button=True,
            bubble_full_width=False
        )

        msg = gr.Textbox(label="Input", placeholder="Type your question...")
        # ClearButton clears both components on click; no extra handler needed.
        clear = gr.ClearButton([msg, chatbot])

        with gr.Accordion("Settings", open=False):
            gr.Markdown("Adjust model parameters here")

            with gr.Row():
                temperature = gr.Slider(0, 1, value=0.3, label="Temperature")
                max_new_tokens = gr.Slider(128, 4096, value=1024, step=1, label="Max New Tokens")

            with gr.Row():
                # BUG FIX: the slider maximum (32000) was below its default
                # value (81920); the range now covers the default.
                max_tokens = gr.Slider(128, 81920, value=81920, step=1, label="Max Tokens")
                max_round = gr.Slider(1, 50, value=30, step=1, label="Max Round")

        def respond(message, chat_history, temperature, max_new_tokens, max_tokens, max_round):
            try:
                # Convert Gradio pair-style history into the agent's
                # role/content message format.
                agent_history = []
                for user_msg, bot_msg in chat_history:
                    if user_msg:
                        agent_history.append({"role": "user", "content": user_msg})
                    if bot_msg:
                        agent_history.append({"role": "assistant", "content": bot_msg})

                # Sliders yield numbers; cast the integer-valued ones so the
                # agent receives ints as with the previous hard-coded values.
                response = agent.run_gradio_chat(
                    agent_history + [{"role": "user", "content": message}],
                    temperature=temperature,
                    max_new_tokens=int(max_new_tokens),
                    max_tokens=int(max_tokens),
                    multi_agent=False,
                    conversation=[],
                    max_round=int(max_round)
                )

                # Concatenate streamed chunks into a single reply string.
                full_response = ""
                for chunk in response:
                    if isinstance(chunk, dict):
                        full_response += chunk.get("content", "")
                    else:
                        full_response += str(chunk)

                return chat_history + [(message, full_response)]

            except Exception as e:
                logger.error(f"Error in response handling: {str(e)}")
                return chat_history + [(message, f"Error: {str(e)}")]

        msg.submit(
            respond,
            [msg, chatbot, temperature, max_new_tokens, max_tokens, max_round],
            [chatbot]
        )

    return demo
218
 
219
def main():
    """Application entry point: build the agent, assemble the UI, serve it.

    Any startup failure is logged and re-raised so the process exits
    non-zero instead of hanging half-initialised.
    """
    try:
        tx_agent = create_agent()
        ui = create_demo(tx_agent)
        ui.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise