Ali2206 committed
Commit 5ffaf72 · verified · 1 Parent(s): 253ca2e

Update app.py

Files changed (1): app.py (+49 -93)
app.py CHANGED
@@ -1,23 +1,21 @@
 import os
-import torch
 import json
+import torch
 import logging
 import gradio as gr
 from importlib.resources import files
 from txagent import TxAgent
 from tooluniverse import ToolUniverse
 
-# Setup logging
 logging.basicConfig(
     level=logging.INFO,
     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 )
 logger = logging.getLogger(__name__)
 
-# Env vars
-current_dir = os.path.dirname(os.path.abspath(__file__))
 os.environ["MKL_THREADING_LAYER"] = "GNU"
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+current_dir = os.path.dirname(os.path.abspath(__file__))
 
 CONFIG = {
     "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
@@ -32,74 +30,46 @@ CONFIG = {
     }
 }
 
-chat_css = """
-.gr-button { font-size: 20px !important; }
-.gr-button svg { width: 32px !important; height: 32px !important; }
-"""
-
-def safe_load_embeddings(filepath: str) -> any:
+def safe_load_embeddings(filepath):
     try:
         return torch.load(filepath, weights_only=True)
     except Exception as e:
-        logger.warning(f"Secure load failed, trying with weights_only=False: {str(e)}")
+        logger.warning(f"Retrying with weights_only=False due to: {e}")
         try:
             return torch.load(filepath, weights_only=False)
         except Exception as e:
-            logger.error(f"Failed to load embeddings: {str(e)}")
+            logger.error(f"Failed to load embeddings: {e}")
             return None
 
 def patch_embedding_loading():
-    try:
-        from txagent.toolrag import ToolRAGModel
-
-        def patched_load(self, tooluniverse):
-            try:
-                if not os.path.exists(CONFIG["embedding_filename"]):
-                    logger.error(f"Embedding file not found: {CONFIG['embedding_filename']}")
-                    return False
-
-                self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
-
-                if hasattr(tooluniverse, 'get_all_tools'):
-                    tools = tooluniverse.get_all_tools()
-                elif hasattr(tooluniverse, 'tools'):
-                    tools = tooluniverse.tools
-                else:
-                    logger.error("No method found to access tools from ToolUniverse")
-                    return False
-
-                if len(tools) != len(self.tool_desc_embedding):
-                    logger.warning("Tool count and embedding count mismatch.")
-                    if len(tools) < len(self.tool_desc_embedding):
-                        self.tool_desc_embedding = self.tool_desc_embedding[:len(tools)]
-                    else:
-                        last_emb = self.tool_desc_embedding[-1]
-                        padding = [last_emb] * (len(tools) - len(self.tool_desc_embedding))
-                        self.tool_desc_embedding = torch.cat([self.tool_desc_embedding] + padding)
-
-                return True
-
-            except Exception as e:
-                logger.error(f"Failed to load embeddings: {str(e)}")
-                return False
-
-        ToolRAGModel.load_tool_desc_embedding = patched_load
-        logger.info("Successfully patched ToolRAGModel")
-
-    except Exception as e:
-        logger.error(f"Failed to patch embedding loader: {str(e)}")
+    from txagent.toolrag import ToolRAGModel
+    def patched_load(self, tooluniverse):
+        try:
+            if not os.path.exists(CONFIG["embedding_filename"]):
+                return False
+            self.tool_desc_embedding = safe_load_embeddings(CONFIG["embedding_filename"])
+
+            tools = tooluniverse.get_all_tools() if hasattr(tooluniverse, "get_all_tools") else getattr(tooluniverse, "tools", [])
+            if len(tools) != len(self.tool_desc_embedding):
+                logger.warning("Tool count mismatch.")
+                self.tool_desc_embedding = self.tool_desc_embedding[:len(tools)]
+            return True
+        except Exception as e:
+            logger.error(f"Embedding load failed: {e}")
+            return False
+
+    ToolRAGModel.load_tool_desc_embedding = patched_load
 
 def prepare_tool_files():
     os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
     if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
         try:
             tu = ToolUniverse()
-            tools = tu.get_all_tools() if hasattr(tu, 'get_all_tools') else getattr(tu, 'tools', [])
+            tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
             with open(CONFIG["tool_files"]["new_tool"], "w") as f:
                 json.dump(tools, f, indent=2)
-            logger.info(f"Saved {len(tools)} tools to {CONFIG['tool_files']['new_tool']}")
         except Exception as e:
-            logger.error(f"Failed to prepare tool files: {str(e)}")
+            logger.error(f"Tool generation failed: {e}")
 
 def create_agent():
     patch_embedding_loading()
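
Editor's note: the hunk above leans on two techniques worth isolating. `safe_load_embeddings` first tries `torch.load(..., weights_only=True)`, which unpickles only tensors and primitives, and falls back to full unpickling only if that fails; `patch_embedding_loading` then rebinds `load_tool_desc_embedding` on the `ToolRAGModel` class, so every instance, existing or future, picks up the replacement. A minimal runnable sketch of both patterns, with a hypothetical `DummyRAG` standing in for txagent's real `ToolRAGModel`:

```python
import torch

def safe_load(path):
    # First attempt: weights_only=True restricts unpickling to tensors/primitives.
    try:
        return torch.load(path, weights_only=True)
    except Exception:
        # Fallback: full unpickling; only acceptable for files you trust.
        return torch.load(path, weights_only=False)

class DummyRAG:  # hypothetical stand-in, not txagent's ToolRAGModel
    def load_tool_desc_embedding(self, tooluniverse):
        raise NotImplementedError("original loader")

def patched_load(self, tooluniverse):
    self.tool_desc_embedding = safe_load("/tmp/emb.pt")
    return self.tool_desc_embedding is not None

# Rebinding on the class (not on an instance) reroutes the method for all instances.
DummyRAG.load_tool_desc_embedding = patched_load

torch.save(torch.randn(3, 4), "/tmp/emb.pt")
assert DummyRAG().load_tool_desc_embedding(None) is True
```

Note that the new `patched_load` only truncates `tool_desc_embedding` on a count mismatch; the old padding branch, which repeated the last embedding, is dropped.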
@@ -112,83 +82,69 @@ def create_agent():
             force_finish=True,
             enable_checker=True,
             step_rag_num=10,
-            seed=100,
-            additional_default_tools=['DirectResponse', 'RequireClarification']
+            seed=42,
+            additional_default_tools=["DirectResponse", "RequireClarification"]
         )
         agent.init_model()
         return agent
     except Exception as e:
-        logger.error(f"Failed to create TxAgent: {str(e)}")
+        logger.error(f"Agent initialization failed: {e}")
         raise
 
-# ✅ GRADIO 5.x-compatible message format
+# ✅ FIXED: Proper message formatting
 def respond(msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
     if not isinstance(msg, str) or len(msg.strip()) <= 10:
-        return chat_history + [{"role": "assistant", "content": "Hi, I am TxAgent. Please provide a valid question with more than 10 characters."}]
+        return chat_history + [{"role": "assistant", "content": "Hi, I am TxAgent. Please provide a valid message longer than 10 characters."}]
 
-    chat_history = chat_history + [{"role": "user", "content": msg.strip()}]
-
-    print("\n==== DEBUG ====")
-    print("User Message:", msg)
-    print("Chat History:", chat_history)
-    print("================\n")
+    message = msg.strip()
+    chat_history.append({"role": "user", "content": message})
+    formatted_history = [(m["role"], m["content"]) for m in chat_history]
 
     try:
-        formatted_history = [(m["role"], m["content"]) for m in chat_history]
-
         response_generator = agent.run_gradio_chat(
-            formatted_history,
-            temperature,
-            max_new_tokens,
-            max_tokens,
-            multi_agent,
-            conversation,
-            max_round
+            message=message,
+            history=formatted_history,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            max_token=max_tokens,
+            call_agent=multi_agent,
+            conversation=conversation,
+            max_round=max_round,
+            seed=42,
+            call_agent_level=None,
+            sub_agent_task=None
        )
-
         collected = ""
         for chunk in response_generator:
-            if isinstance(chunk, dict):
-                collected += chunk.get("content", "")
-            else:
-                collected += str(chunk)
-
+            collected += chunk.get("content", "") if isinstance(chunk, dict) else str(chunk)
         chat_history.append({"role": "assistant", "content": collected})
     except Exception as e:
-        chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
-
+        chat_history.append({"role": "assistant", "content": f"Error: {e}"})
     return chat_history
 
 def create_demo(agent):
-    with gr.Blocks(css=chat_css) as demo:
+    with gr.Blocks(css=".gr-button { font-size: 18px !important; }") as demo:
         chatbot = gr.Chatbot(label="TxAgent", type="messages", render_markdown=True)
-        msg = gr.Textbox(label="Your question", placeholder="Type your biomedical query...", scale=6)
+        msg = gr.Textbox(label="Your question", placeholder="Ask a biomedical question...", scale=6)
         with gr.Row():
             temp = gr.Slider(0, 1, value=0.3, label="Temperature")
             max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
            max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
             max_rounds = gr.Slider(1, 30, value=30, label="Max Rounds")
             multi_agent = gr.Checkbox(label="Multi-Agent Mode")
-        with gr.Row():
-            submit = gr.Button("Ask TxAgent")
-
+        submit = gr.Button("Ask TxAgent")
         submit.click(
             respond,
             inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
             outputs=[chatbot]
         )
-
     return demo
 
 def main():
-    try:
-        global agent
-        agent = create_agent()
-        demo = create_demo(agent)
-        demo.launch(share=False)  # Set to True to get a public link
-    except Exception as e:
-        logger.error(f"Application failed to start: {str(e)}")
-        raise
+    global agent
+    agent = create_agent()
+    demo = create_demo(agent)
+    demo.launch(share=False)
 
 if __name__ == "__main__":
     main()
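
Editor's note: the rewritten `respond` targets `gr.Chatbot(type="messages")`, which represents history as a flat list of `{"role": ..., "content": ...}` dicts rather than `(user, bot)` tuples. A minimal self-contained sketch of that wiring, independent of TxAgent (the `echo` handler is invented for illustration and is not part of this commit):

```python
import gradio as gr

def echo(msg, history):
    # "messages" format: one role/content dict per turn.
    history = history + [
        {"role": "user", "content": msg},
        {"role": "assistant", "content": f"You said: {msg}"},
    ]
    return history, ""  # second output clears the textbox

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Echo", type="messages")
    box = gr.Textbox(placeholder="Say something...")
    box.submit(echo, inputs=[box, chatbot], outputs=[chatbot, box])

if __name__ == "__main__":
    demo.launch()
```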
 
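One more detail from the new `respond` worth isolating: `run_gradio_chat` is consumed as a generator that may yield either dict chunks or plain strings, and the accumulation loop normalizes both. A tiny sketch under that assumption (the `fake_stream` generator is invented; the real chunk shapes come from txagent):

```python
def fake_stream():
    # Stand-in for agent.run_gradio_chat(...): yields mixed chunk types.
    yield {"content": "Aspirin "}
    yield "irreversibly inhibits "
    yield {"content": "COX-1 and COX-2."}

collected = ""
for chunk in fake_stream():
    collected += chunk.get("content", "") if isinstance(chunk, dict) else str(chunk)

print(collected)  # Aspirin irreversibly inhibits COX-1 and COX-2.
```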