Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from txagent import TxAgent
|
| 3 |
-
from tooluniverse import ToolUniverse
|
| 4 |
-
import os
|
| 5 |
import logging
|
|
|
|
| 6 |
|
| 7 |
# Configure logging
|
| 8 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -13,16 +12,18 @@ class TxAgentApp:
|
|
| 13 |
self.agent = self._initialize_agent()
|
| 14 |
|
| 15 |
def _initialize_agent(self):
|
| 16 |
-
"""Initialize the TxAgent with
|
| 17 |
try:
|
| 18 |
-
logger.info("Initializing TxAgent
|
| 19 |
agent = TxAgent(
|
| 20 |
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
| 21 |
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
| 26 |
)
|
| 27 |
logger.info("Model loading complete")
|
| 28 |
return agent
|
|
@@ -64,11 +65,7 @@ with gr.Blocks(
|
|
| 64 |
with gr.Column(scale=2):
|
| 65 |
chatbot = gr.Chatbot(
|
| 66 |
height=650,
|
| 67 |
-
bubble_full_width=False
|
| 68 |
-
avatar_images=(
|
| 69 |
-
"https://example.com/user.png", # Replace with actual avatars
|
| 70 |
-
"https://example.com/bot.png"
|
| 71 |
-
)
|
| 72 |
)
|
| 73 |
with gr.Column(scale=1):
|
| 74 |
with gr.Accordion("⚙️ Parameters", open=False):
|
|
@@ -116,6 +113,5 @@ if __name__ == "__main__":
|
|
| 116 |
).launch(
|
| 117 |
server_name="0.0.0.0",
|
| 118 |
server_port=7860,
|
| 119 |
-
share=False
|
| 120 |
-
favicon_path="icon.png" # Add favicon
|
| 121 |
)
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from txagent import TxAgent
|
|
|
|
|
|
|
| 3 |
import logging
|
| 4 |
+
import os
|
| 5 |
|
| 6 |
# Configure logging
|
| 7 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 12 |
self.agent = self._initialize_agent()
|
| 13 |
|
| 14 |
def _initialize_agent(self):
|
| 15 |
+
"""Initialize the TxAgent with proper parameters"""
|
| 16 |
try:
|
| 17 |
+
logger.info("Initializing TxAgent...")
|
| 18 |
agent = TxAgent(
|
| 19 |
model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
|
| 20 |
rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
|
| 21 |
+
# Remove unsupported parameters
|
| 22 |
+
force_finish=True,
|
| 23 |
+
enable_checker=True,
|
| 24 |
+
step_rag_num=10,
|
| 25 |
+
seed=42,
|
| 26 |
+
additional_default_tools=["DirectResponse", "RequireClarification"]
|
| 27 |
)
|
| 28 |
logger.info("Model loading complete")
|
| 29 |
return agent
|
|
|
|
| 65 |
with gr.Column(scale=2):
|
| 66 |
chatbot = gr.Chatbot(
|
| 67 |
height=650,
|
| 68 |
+
bubble_full_width=False
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
)
|
| 70 |
with gr.Column(scale=1):
|
| 71 |
with gr.Accordion("⚙️ Parameters", open=False):
|
|
|
|
| 113 |
).launch(
|
| 114 |
server_name="0.0.0.0",
|
| 115 |
server_port=7860,
|
| 116 |
+
share=False
|
|
|
|
| 117 |
)
|