import sys | |
import os | |
# Add the repository's src directory to the Python path so txagent can be imported
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) | |
from txagent.txagent import TxAgent | |
def init_agent(base_dir="/data"):
    """Create, configure, and load a TxAgent instance.

    Creates the model and tool cache directories under *base_dir*. Points the
    Hugging Face cache environment variables at the model cache directory.
    Constructs a ``TxAgent`` with a fixed configuration and calls
    ``init_model()`` on it before returning it.

    Args:
        base_dir: Root directory for the caches. Defaults to ``"/data"``, the
            Hugging Face persistent-storage mount this script targets. Pass a
            different path to run elsewhere, e.g. locally or in tests.

    Returns:
        The ``TxAgent`` instance after ``init_model()`` has been called.
    """
    model_cache_dir = os.path.join(base_dir, "hf_cache")
    tool_cache_dir = os.path.join(base_dir, "tool_cache")

    # Create the cache folders up front; exist_ok makes repeated calls safe.
    os.makedirs(model_cache_dir, exist_ok=True)
    os.makedirs(tool_cache_dir, exist_ok=True)

    # Point Hugging Face caches at persistent storage so downloaded models
    # survive restarts.
    # NOTE(review): transformers / huggingface_hub read these variables when
    # they are first imported. If the module-level TxAgent import already
    # imported them, these assignments may not take effect. Consider exporting
    # the variables before that import. TODO confirm.
    os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
    os.environ["HF_HOME"] = model_cache_dir

    # Hub identifiers of the main model and the tool-retrieval (RAG) model.
    model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
    rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
    # NOTE(review): nothing in this function creates new_tool.json.
    # Presumably it must already exist in tool_cache_dir. Confirm how TxAgent
    # handles a missing file.
    tool_files_dict = {
        "new_tool": os.path.join(tool_cache_dir, "new_tool.json"),
    }

    agent = TxAgent(
        model_name=model_name,
        rag_model_name=rag_model_name,
        tool_files_dict=tool_files_dict,
        force_finish=True,
        enable_checker=True,
        step_rag_num=10,
        seed=100,
        additional_default_tools=[],
    )
    # Loading the model is a separate step from constructing the agent.
    agent.init_model()
    return agent