Ali2206 committed on
Commit
81ad366
·
verified ·
1 Parent(s): c181ca6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -153
app.py CHANGED
@@ -1,178 +1,121 @@
import os
import json
import logging

import numpy
import torch
import gradio as gr
from importlib.resources import files

from txagent import TxAgent
from tooluniverse import ToolUniverse

# Logging: timestamped, logger-name-tagged records at INFO level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Environment setup: must happen before the heavy libraries spin up threads.
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

current_dir = os.path.dirname(os.path.abspath(__file__))

# Tool-definition JSON files bundled inside the tooluniverse package.
_TU_DATA = files('tooluniverse.data')

# Static application configuration: model identifiers and tool files.
CONFIG = {
    "model_name": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
    "rag_model_name": "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
    "tool_files": {
        "opentarget": str(_TU_DATA.joinpath('opentarget_tools.json')),
        "fda_drug_label": str(_TU_DATA.joinpath('fda_drug_labeling_tools.json')),
        "special_tools": str(_TU_DATA.joinpath('special_tools.json')),
        "monarch": str(_TU_DATA.joinpath('monarch_tools.json')),
        "new_tool": os.path.join(current_dir, 'data', 'new_tool.json'),
    },
}
class TxAgentApp:
    """Gradio front-end wrapping a TxAgent model and its tool files."""

    def __init__(self):
        # Built eagerly so a bad model/tool configuration fails at startup.
        self.agent = None
        self.initialize_agent()

    def initialize_agent(self):
        """Initialize the TxAgent with proper error handling.

        Raises:
            Exception: re-raised after logging if tool preparation or
                model initialization fails.
        """
        try:
            self.prepare_tool_files()
            logger.info("Initializing TxAgent...")

            self.agent = TxAgent(
                model_name=CONFIG["model_name"],
                rag_model_name=CONFIG["rag_model_name"],
                tool_files_dict=CONFIG["tool_files"],
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=42,
                additional_default_tools=["DirectResponse", "RequireClarification"]
            )

            logger.info("Initializing model...")
            self.agent.init_model()
            logger.info("Agent initialization complete")

        except Exception as e:
            logger.error("Failed to initialize agent: %s", e)
            raise

    def prepare_tool_files(self):
        """Create data/new_tool.json from the ToolUniverse registry if absent."""
        try:
            os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True)
            if not os.path.exists(CONFIG["tool_files"]["new_tool"]):
                logger.info("Creating new_tool.json...")
                tu = ToolUniverse()
                # ToolUniverse's API differs between versions; fall back gracefully.
                tools = tu.get_all_tools() if hasattr(tu, "get_all_tools") else getattr(tu, "tools", [])
                with open(CONFIG["tool_files"]["new_tool"], "w") as f:
                    json.dump(tools, f, indent=2)
        except Exception as e:
            logger.error("Failed to prepare tool files: %s", e)
            raise

    def respond(self, msg, chat_history, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
        """Handle one user message and return the updated messages-format history.

        Args:
            msg: raw user input; rejected unless a str longer than 10 chars.
            chat_history: list of {"role", "content"} dicts (mutated in place).
            temperature / max_new_tokens / max_tokens: generation settings.
            multi_agent: passed through as ``call_agent``.
            conversation: opaque state forwarded to the agent.
            max_round: maximum reasoning rounds for the agent.

        Returns:
            The chat history with the user turn and assistant reply appended;
            on failure the error text is appended as an assistant turn instead.
        """
        try:
            if not isinstance(msg, str) or len(msg.strip()) <= 10:
                return chat_history + [{"role": "assistant", "content": "Please provide a valid message longer than 10 characters."}]

            message = msg.strip()
            chat_history.append({"role": "user", "content": message})
            # Agent expects (role, content) tuples rather than message dicts.
            formatted_history = [(m["role"], m["content"]) for m in chat_history if "role" in m and "content" in m]

            logger.info("Processing message: %s...", message[:100])

            response_generator = self.agent.run_gradio_chat(
                message=message,
                history=formatted_history,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_token=max_tokens,
                call_agent=multi_agent,
                conversation=conversation,
                max_round=max_round,
                seed=42
            )

            # Chunks may be dicts with "content", plain strings, or other
            # printable objects — accumulate them all into one reply.
            collected = ""
            for chunk in response_generator:
                if isinstance(chunk, dict) and "content" in chunk:
                    collected += chunk["content"]
                elif isinstance(chunk, str):
                    collected += chunk
                elif chunk is not None:
                    collected += str(chunk)

            chat_history.append({"role": "assistant", "content": collected or "No response generated."})
            return chat_history

        except Exception as e:
            logger.error("Error in respond function: %s", e)
            chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
            return chat_history

    def create_demo(self):
        """Create and return the Gradio interface."""
        with gr.Blocks(title="TxAgent", css=".gr-button { font-size: 18px !important; }") as demo:
            gr.Markdown("# TxAgent - Biomedical AI Assistant")

            with gr.Row():
                with gr.Column(scale=3):
                    # FIX: type="messages" is required because respond()
                    # produces {"role", "content"} dicts; the default tuples
                    # format would reject/garble that history.
                    chatbot = gr.Chatbot(
                        label="Conversation",
                        height=600,
                        type="messages"
                    )
                    msg = gr.Textbox(
                        label="Your question",
                        placeholder="Ask a biomedical question...",
                        lines=3
                    )
                    submit = gr.Button("Ask", variant="primary")

                with gr.Column(scale=1):
                    temp = gr.Slider(0, 1, value=0.3, label="Temperature")
                    max_new_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max New Tokens")
                    max_tokens = gr.Slider(128, 81920, value=81920, step=1024, label="Max Total Tokens")
                    max_rounds = gr.Slider(1, 30, value=10, step=1, label="Max Rounds")
                    multi_agent = gr.Checkbox(label="Multi-Agent Mode", value=False)
                    clear_btn = gr.Button("Clear Chat")

            submit.click(
                self.respond,
                inputs=[msg, chatbot, temp, max_new_tokens, max_tokens, multi_agent, gr.State([]), max_rounds],
                outputs=[chatbot]
            )
            clear_btn.click(lambda: [], None, chatbot, queue=False)

            # Dummy load event to ensure the app stays alive.
            demo.load(lambda: None, None, None)

        return demo
def main():
    """Application entry point: build the app, then launch the Gradio UI."""
    try:
        logger.info("Starting TxAgent application...")
        application = TxAgentApp()
        interface = application.create_demo()

        logger.info("Launching Gradio interface...")
        # Bind on all interfaces so the container/Space can expose port 7860.
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
            show_error=True,
        )
    except Exception as exc:
        logger.error(f"Application failed to start: {exc}")
        raise


if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import logging

import gradio as gr

from txagent import TxAgent
from tooluniverse import ToolUniverse

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
11
  class TxAgentApp:
12
  def __init__(self):
13
+ self.agent = self._initialize_agent()
 
14
 
15
+ def _initialize_agent(self):
16
+ """Initialize the TxAgent with A100 optimizations"""
17
  try:
18
+ logger.info("Initializing TxAgent with A100 optimizations...")
19
+ agent = TxAgent(
20
+ model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
21
+ rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
22
+ device_map="auto",
23
+ torch_dtype="auto",
24
+ enable_xformers=True,
25
+ max_model_len=8192 # Optimized for A100 80GB
 
 
 
 
26
  )
27
+ logger.info("Model loading complete")
28
+ return agent
 
 
 
29
  except Exception as e:
30
+ logger.error(f"Initialization failed: {str(e)}")
31
  raise
32
 
33
+ def respond(self, message, history):
34
+ """Handle streaming responses with Gradio 5.23"""
35
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  response_generator = self.agent.run_gradio_chat(
37
  message=message,
38
+ history=history,
39
+ temperature=0.3,
40
+ max_new_tokens=2048,
41
+ stream=True
 
 
 
 
42
  )
43
+
 
44
  for chunk in response_generator:
45
+ if isinstance(chunk, dict):
46
+ yield chunk.get("content", "")
47
  elif isinstance(chunk, str):
48
+ yield chunk
 
 
 
 
 
 
49
  except Exception as e:
50
+ logger.error(f"Generation error: {str(e)}")
51
+ yield f"⚠️ Error: {str(e)}"
52
# Initialize the app (loads the model eagerly at import time).
app = TxAgentApp()

# Gradio 5.x interface
with gr.Blocks(
    title="TxAgent Medical AI",
    theme=gr.themes.Soft(spacing_size="sm", radius_size="none")
) as demo:
    gr.Markdown("""<h1 style='text-align: center'>TxAgent Biomedical Assistant</h1>""")

    with gr.Row(equal_height=False):
        with gr.Column(scale=2):
            # FIX: type="messages" matches the {"role", "content"} dicts the
            # respond handler streams back. (bubble_full_width was dropped:
            # it is deprecated/ignored in Gradio 5.)
            chatbot = gr.Chatbot(
                height=650,
                type="messages",
                avatar_images=(
                    "https://example.com/user.png",  # Replace with actual avatars
                    "https://example.com/bot.png"
                )
            )
        with gr.Column(scale=1):
            with gr.Accordion("⚙️ Parameters", open=False):
                # NOTE(review): these controls are not wired into app.respond,
                # which currently uses fixed generation settings — confirm
                # whether they should be passed through.
                temperature = gr.Slider(0, 1, value=0.3, label="Creativity")
                max_tokens = gr.Slider(128, 4096, value=1024, step=128, label="Max Response Length")
                rag_toggle = gr.Checkbox(value=True, label="Enable RAG")

            msg = gr.Textbox(
                label="Your medical query",
                placeholder="Enter your biomedical question...",
                lines=5,
                max_lines=10
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear History")

    # Chat interface: Enter key and Submit button share one handler.
    # FIX: each event needs a unique api_name — two events both named
    # "chat" collide in the generated API schema.
    msg.submit(
        app.respond,
        [msg, chatbot],
        chatbot,
        api_name="chat"
    ).then(
        lambda: "", None, msg
    )

    submit_btn.click(
        app.respond,
        [msg, chatbot],
        chatbot,
        api_name="chat_submit"
    ).then(
        lambda: "", None, msg
    )

    clear_btn.click(
        lambda: [], None, chatbot
    )

# Launch configuration
if __name__ == "__main__":
    demo.queue(
        # FIX: `concurrency_count` was removed in Gradio 4+; the replacement
        # keyword is `default_concurrency_limit`.
        default_concurrency_limit=5,
        max_size=20
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        # Only pass the favicon if the file exists; a missing path aborts launch.
        favicon_path="icon.png" if os.path.exists("icon.png") else None
    )