Ali2206 commited on
Commit
dec1312
·
verified ·
1 Parent(s): 417dc33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import gradio as gr
2
  from txagent import TxAgent
3
- from tooluniverse import ToolUniverse
4
- import os
5
  import logging
 
6
 
7
  # Configure logging
8
  logging.basicConfig(level=logging.INFO)
@@ -13,16 +12,18 @@ class TxAgentApp:
13
  self.agent = self._initialize_agent()
14
 
15
  def _initialize_agent(self):
16
- """Initialize the TxAgent with A100 optimizations"""
17
  try:
18
- logger.info("Initializing TxAgent with A100 optimizations...")
19
  agent = TxAgent(
20
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
21
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
22
- device_map="auto",
23
- torch_dtype="auto",
24
- enable_xformers=True,
25
- max_model_len=8192 # Optimized for A100 80GB
 
 
26
  )
27
  logger.info("Model loading complete")
28
  return agent
@@ -64,11 +65,7 @@ with gr.Blocks(
64
  with gr.Column(scale=2):
65
  chatbot = gr.Chatbot(
66
  height=650,
67
- bubble_full_width=False,
68
- avatar_images=(
69
- "https://example.com/user.png", # Replace with actual avatars
70
- "https://example.com/bot.png"
71
- )
72
  )
73
  with gr.Column(scale=1):
74
  with gr.Accordion("⚙️ Parameters", open=False):
@@ -116,6 +113,5 @@ if __name__ == "__main__":
116
  ).launch(
117
  server_name="0.0.0.0",
118
  server_port=7860,
119
- share=False,
120
- favicon_path="icon.png" # Add favicon
121
  )
 
1
  import gradio as gr
2
  from txagent import TxAgent
 
 
3
  import logging
4
+ import os
5
 
6
  # Configure logging
7
  logging.basicConfig(level=logging.INFO)
 
12
  self.agent = self._initialize_agent()
13
 
14
  def _initialize_agent(self):
15
+ """Initialize the TxAgent with proper parameters"""
16
  try:
17
+ logger.info("Initializing TxAgent...")
18
  agent = TxAgent(
19
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
20
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
21
+ # Remove unsupported parameters
22
+ force_finish=True,
23
+ enable_checker=True,
24
+ step_rag_num=10,
25
+ seed=42,
26
+ additional_default_tools=["DirectResponse", "RequireClarification"]
27
  )
28
  logger.info("Model loading complete")
29
  return agent
 
65
  with gr.Column(scale=2):
66
  chatbot = gr.Chatbot(
67
  height=650,
68
+ bubble_full_width=False
 
 
 
 
69
  )
70
  with gr.Column(scale=1):
71
  with gr.Accordion("⚙️ Parameters", open=False):
 
113
  ).launch(
114
  server_name="0.0.0.0",
115
  server_port=7860,
116
+ share=False
 
117
  )