# TxAgent-Api: src/txagent/txagent.py
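"""TxAgent: a thin wrapper around a Hugging Face causal LM for chat, with an
optional SentenceTransformer encoder intended for retrieval-augmented steps."""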
import logging
import torch
from typing import Dict, Optional, List
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from sentence_transformers import SentenceTransformer
# Configure logging for Hugging Face Spaces
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("TxAgent")
class TxAgent:
def __init__(self,
model_name: str,
rag_model_name: str,
tool_files_dict: Optional[Dict] = None,
enable_finish: bool = True,
enable_rag: bool = False,
force_finish: bool = True,
enable_checker: bool = True,
step_rag_num: int = 4,
seed: Optional[int] = None):
# Initialization parameters
self.model_name = model_name
self.rag_model_name = rag_model_name
self.tool_files_dict = tool_files_dict or {}
self.enable_finish = enable_finish
self.enable_rag = enable_rag
self.force_finish = force_finish
self.enable_checker = enable_checker
self.step_rag_num = step_rag_num
        self.seed = seed
        # Apply the seed so sampling is reproducible when one is provided
        if seed is not None:
            torch.manual_seed(seed)
# Device setup
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Models
self.model = None
self.tokenizer = None
self.rag_model = None
# Prompts
self.chat_prompt = "You are a helpful assistant for user chat."
logger.info(f"Initialized TxAgent with model: {model_name}")
def init_model(self):
"""Initialize all models and components"""
try:
self.load_llm_model()
if self.enable_rag:
self.load_rag_model()
logger.info("Models initialized successfully")
except Exception as e:
logger.error(f"Model initialization failed: {str(e)}")
raise
def load_llm_model(self):
"""Load the main LLM model"""
try:
logger.info(f"Loading LLM model: {self.model_name}")
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
device_map="auto",
trust_remote_code=True
)
logger.info(f"LLM model loaded on {self.device}")
except Exception as e:
logger.error(f"Failed to load LLM model: {str(e)}")
raise
def load_rag_model(self):
"""Load the RAG model"""
try:
logger.info(f"Loading RAG model: {self.rag_model_name}")
self.rag_model = SentenceTransformer(
self.rag_model_name,
device=str(self.device)
)
logger.info("RAG model loaded successfully")
except Exception as e:
logger.error(f"Failed to load RAG model: {str(e)}")
raise
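
    # A sketch of how the loaded SentenceTransformer and `step_rag_num` could
    # back a simple retrieval step. Assumption: this method is illustrative and
    # not part of the upstream TxAgent API; the real agent may wire retrieval
    # into its tool-calling loop differently.
    def retrieve(self, query: str, documents: List[str]) -> List[str]:
        """Rank `documents` by cosine similarity to `query`; return the top `step_rag_num`"""
        if self.rag_model is None:
            raise RuntimeError("RAG model not loaded; set enable_rag=True and call init_model()")
        # Encode the query and candidate documents into a shared embedding space
        query_emb = self.rag_model.encode(query, convert_to_tensor=True)
        doc_embs = self.rag_model.encode(documents, convert_to_tensor=True)
        # Cosine similarity of the query against every document embedding
        scores = torch.nn.functional.cosine_similarity(
            query_emb.unsqueeze(0), doc_embs, dim=1
        )
        top_k = min(self.step_rag_num, len(documents))
        top_idx = torch.topk(scores, k=top_k).indices.tolist()
        return [documents[i] for i in top_idx]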
def chat(self, message: str, history: Optional[List[Dict]] = None,
temperature: float = 0.7, max_new_tokens: int = 512) -> str:
"""Handle chat conversations"""
try:
conversation = []
# Initialize with system prompt
conversation.append({"role": "system", "content": self.chat_prompt})
# Add history if provided
if history:
for msg in history:
conversation.append({"role": msg["role"], "content": msg["content"]})
# Add current message
conversation.append({"role": "user", "content": message})
            # Tokenize the conversation with the model's chat template,
            # keeping the attention mask so generate() handles padding correctly
            inputs = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True
            ).to(self.device)
generation_config = GenerationConfig(
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
            outputs = self.model.generate(
                **inputs,
                generation_config=generation_config
            )
            # Decode only the newly generated tokens, not the echoed prompt
            prompt_length = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            return response.strip()
except Exception as e:
logger.error(f"Chat failed: {str(e)}")
            raise RuntimeError(f"Chat failed: {str(e)}") from e
    def cleanup(self):
        """Release model references and free any cached GPU memory"""
        try:
            if hasattr(self, 'model'):
                del self.model
            if hasattr(self, 'tokenizer'):
                del self.tokenizer
            if hasattr(self, 'rag_model'):
                del self.rag_model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Resources cleaned up")
        except Exception as e:
            logger.error(f"Cleanup failed: {str(e)}")
            raise
    def __del__(self):
        """Destructor; best-effort cleanup that never raises"""
        try:
            self.cleanup()
        except Exception:
            # Never propagate exceptions from a destructor
            pass
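
# A minimal usage sketch for local smoke-testing. Assumption: the checkpoint
# names below are illustrative placeholders, not the models this Space deploys.
if __name__ == "__main__":
    agent = TxAgent(
        model_name="Qwen/Qwen2.5-0.5B-Instruct",                  # hypothetical small chat model
        rag_model_name="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical retriever
        enable_rag=False,
        seed=42,
    )
    agent.init_model()
    print(agent.chat("Hello, what can you help me with?"))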