Ali2206 committed
Commit dcb29df · verified · 1 Parent(s): 0206b2a

Update app.py

Files changed (1)
  1. app.py +59 -55
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import os
 import logging
 from txagent import TxAgent
+from tooluniverse import ToolUniverse
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -15,9 +16,19 @@ class TxAgentApp:
         """Initialize the TxAgent with proper parameters"""
         try:
             logger.info("Initializing TxAgent...")
+
+            # Initialize default tool files
+            tool_files = {
+                "opentarget": "opentarget_tools.json",
+                "fda_drug_label": "fda_drug_labeling_tools.json",
+                "special_tools": "special_tools.json",
+                "monarch": "monarch_tools.json"
+            }
+
             agent = TxAgent(
                 model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
                 rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+                tool_files_dict=tool_files,  # This is critical!
                 enable_finish=True,
                 enable_rag=True,
                 enable_summary=False,
@@ -33,6 +44,10 @@ class TxAgentApp:
                 enable_chat=False,
                 additional_default_tools=["DirectResponse", "RequireClarification"]
             )
+
+            # Explicitly initialize the model
+            agent.init_model()
+
             logger.info("Model loading complete")
             return agent
         except Exception as e:
@@ -43,10 +58,14 @@ class TxAgentApp:
         """Handle streaming responses with Gradio"""
         try:
             if not isinstance(message, str) or len(message.strip()) <= 10:
-                yield "Please provide a valid message longer than 10 characters."
-                return
+                return chat_history + [("", "Please provide a valid message longer than 10 characters.")]
 
-            response_generator = self.agent.run_gradio_chat(
+            # Convert chat history to list of tuples if needed
+            if chat_history and isinstance(chat_history[0], dict):
+                chat_history = [(h["role"], h["content"]) for h in chat_history if "role" in h and "content" in h]
+
+            response = ""
+            for chunk in self.agent.run_gradio_chat(
                 message=message.strip(),
                 history=chat_history,
                 temperature=temperature,
@@ -56,78 +75,63 @@ class TxAgentApp:
                 conversation=conversation_state,
                 max_round=max_round,
                 seed=42
-            )
-
-            full_response = ""
-            for chunk in response_generator:
-                if isinstance(chunk, dict) and "content" in chunk:
-                    content = chunk["content"]
+            ):
+                if isinstance(chunk, dict):
+                    response += chunk.get("content", "")
                 elif isinstance(chunk, str):
-                    content = chunk
+                    response += chunk
                 else:
-                    content = str(chunk)
+                    response += str(chunk)
 
-                full_response += content
-                yield full_response
+                yield chat_history + [("user", message), ("assistant", response)]
 
         except Exception as e:
             logger.error(f"Error in respond function: {str(e)}")
-            yield f"⚠️ Error: {str(e)}"
+            yield chat_history + [("", f"⚠️ Error: {str(e)}")]
 
 def create_demo():
     """Create and configure the Gradio interface"""
     app = TxAgentApp()
 
-    with gr.Blocks(
-        title="TxAgent Medical AI",
-        theme=gr.themes.Soft(spacing_size="sm", radius_size="none")
-    ) as demo:
-        gr.Markdown("""<h1 style='text-align: center'>TxAgent Biomedical Assistant</h1>""")
+    with gr.Blocks(title="TxAgent Medical AI") as demo:
+        gr.Markdown("# TxAgent Biomedical Assistant")
 
-        with gr.Row(equal_height=False):
-            with gr.Column(scale=2):
-                chatbot = gr.Chatbot(
-                    height=650,
-                    bubble_full_width=False,
-                    render_markdown=True
-                )
-                msg = gr.Textbox(
-                    label="Your medical query",
-                    placeholder="Enter your biomedical question...",
-                    lines=5,
-                    max_lines=10
-                )
-
-            with gr.Column(scale=1):
-                with gr.Accordion("⚙️ Parameters", open=False):
-                    temperature = gr.Slider(0, 1, value=0.3, label="Creativity")
-                    max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max Response Length")
-                    max_tokens = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
-                    max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
-                    multi_agent = gr.Checkbox(value=False, label="Multi-Agent Mode")
-
-                submit_btn = gr.Button("Submit", variant="primary")
-                clear_btn = gr.Button("Clear History")
+        chatbot = gr.Chatbot(
+            label="Conversation",
+            height=600,
+            bubble_full_width=False
+        )
+
+        msg = gr.Textbox(
+            label="Your medical query",
+            placeholder="Enter your biomedical question...",
+            lines=3
+        )
+
+        with gr.Row():
+            temp = gr.Slider(0, 1, value=0.3, label="Temperature")
+            max_new_tokens = gr.Slider(128, 4096, value=1024, label="Max New Tokens")
+            max_tokens = gr.Slider(128, 81920, value=81920, label="Max Total Tokens")
+            max_rounds = gr.Slider(1, 30, value=10, label="Max Rounds")
+            multi_agent = gr.Checkbox(label="Multi-Agent Mode")
+
+        submit = gr.Button("Submit")
+        clear = gr.Button("Clear")
 
         conversation_state = gr.State([])
 
-        # Chat interface
-        msg.submit(
+        submit.click(
             app.respond,
-            [msg, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
+            [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
             chatbot
-        ).then(lambda: "", None, msg)
+        )
 
-        submit_btn.click(
+        clear.click(lambda: [], None, chatbot)
+
+        msg.submit(
             app.respond,
-            [msg, chatbot, temperature, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
+            [msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, conversation_state, max_rounds],
             chatbot
-        ).then(lambda: "", None, msg)
-
-        clear_btn.click(
-            lambda: ([], []),
-            None,
-            [chatbot, conversation_state]
         )
 
     return demo
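
Below is a minimal, self-contained sketch of the streaming pattern the updated respond() method relies on: chunks from the agent generator are accumulated into one response string, and the full Gradio chat history is re-yielded after every chunk so the assistant bubble grows in place. The stub generator and its sample chunks are placeholders standing in for TxAgent.run_gradio_chat output, not actual model behavior.

# Sketch only: fake_chunks stands in for self.agent.run_gradio_chat(...),
# which (per the diff) may yield dicts with a "content" key or plain strings.
def fake_chunks():
    yield {"content": "Aspirin irreversibly "}
    yield "inhibits "
    yield {"content": "COX-1 and COX-2."}

def stream_reply(message, chat_history):
    """Accumulate chunks and yield a growing (user, assistant) history."""
    response = ""
    for chunk in fake_chunks():
        if isinstance(chunk, dict):
            response += chunk.get("content", "")
        else:
            response += str(chunk)
        # Each yield is a complete history list, which gr.Chatbot re-renders,
        # so the partial answer replaces the previous one instead of stacking.
        yield chat_history + [("user", message), ("assistant", response)]

if __name__ == "__main__":
    for history in stream_reply("What does aspirin do?", []):
        print(history[-1][1])

With Gradio's tuple-style Chatbot, yielding the whole history on each step is what makes the reply appear to stream.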