import logging
from typing import Dict, List, Optional

import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("TxAgent")


class TxAgent:
    def __init__(self,
                 model_name: str,
                 rag_model_name: str,
                 tool_files_dict: Optional[Dict] = None,
                 enable_finish: bool = True,
                 enable_rag: bool = False,
                 force_finish: bool = True,
                 enable_checker: bool = True,
                 step_rag_num: int = 4,
                 seed: Optional[int] = None):
        self.model_name = model_name
        self.rag_model_name = rag_model_name
        self.tool_files_dict = tool_files_dict or {}
        self.enable_finish = enable_finish
        self.enable_rag = enable_rag
        self.force_finish = force_finish
        self.enable_checker = enable_checker
        self.step_rag_num = step_rag_num
        self.seed = seed

        # Apply the seed so sampled generations are reproducible.
        if seed is not None:
            torch.manual_seed(seed)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Models are loaded lazily via init_model().
        self.model = None
        self.tokenizer = None
        self.rag_model = None

        self.chat_prompt = "You are a helpful assistant for user chat."

        logger.info(f"Initialized TxAgent with model: {model_name}")

    def init_model(self):
        """Initialize all models and components."""
        try:
            self.load_llm_model()
            if self.enable_rag:
                self.load_rag_model()
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.error(f"Model initialization failed: {e}")
            raise

    def load_llm_model(self):
        """Load the main LLM and its tokenizer."""
        try:
            logger.info(f"Loading LLM model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                # Half precision on GPU to cut memory use; full precision on CPU.
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
                trust_remote_code=True
            )
            logger.info(f"LLM model loaded on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load LLM model: {e}")
            raise

    def load_rag_model(self):
        """Load the sentence-embedding model used for RAG retrieval."""
        try:
            logger.info(f"Loading RAG model: {self.rag_model_name}")
            self.rag_model = SentenceTransformer(
                self.rag_model_name,
                device=str(self.device)
            )
            logger.info("RAG model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load RAG model: {e}")
            raise

    def chat(self, message: str, history: Optional[List[Dict]] = None,
             temperature: float = 0.7, max_new_tokens: int = 512) -> str:
        """Handle a chat turn, optionally continuing a prior conversation.

        `history` is a list of {"role": ..., "content": ...} dicts.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Models not loaded; call init_model() first")
        try:
            conversation = [{"role": "system", "content": self.chat_prompt}]
            if history:
                for msg in history:
                    conversation.append({"role": msg["role"], "content": msg["content"]})
            conversation.append({"role": "user", "content": message})

            inputs = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.device)

            generation_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

            with torch.inference_mode():
                outputs = self.model.generate(
                    inputs,
                    generation_config=generation_config
                )

            # Decode only the newly generated tokens, skipping the prompt.
            response = self.tokenizer.decode(
                outputs[0][inputs.shape[1]:], skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            logger.error(f"Chat failed: {e}")
            raise RuntimeError(f"Chat failed: {e}") from e

    def cleanup(self):
        """Release model references and free cached GPU memory."""
        try:
            self.model = None
            self.tokenizer = None
            self.rag_model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Resources cleaned up")
        except Exception as e:
            logger.error(f"Cleanup failed: {e}")
            raise

    def __del__(self):
        """Destructor to ensure resources are released."""
        # Raising from __del__ is unsafe (it may run during interpreter
        # shutdown), so suppress any cleanup errors here.
        try:
            self.cleanup()
        except Exception:
            pass
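

# Example usage: a minimal sketch, not part of the TxAgent API. The two
# checkpoint names below are illustrative assumptions; substitute any causal
# LM and sentence-transformers embedding model you have access to.
if __name__ == "__main__":
    agent = TxAgent(
        model_name="Qwen/Qwen2.5-0.5B-Instruct",
        rag_model_name="sentence-transformers/all-MiniLM-L6-v2",
        enable_rag=False,
        seed=42,
    )
    agent.init_model()
    try:
        print(agent.chat("What does retrieval-augmented generation do?"))
    finally:
        agent.cleanup()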