import logging
import torch
from typing import Dict, List, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from sentence_transformers import SentenceTransformer

# Configure logging for Hugging Face Spaces
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("TxAgent")
class TxAgent:
    def __init__(self,
                 model_name: str,
                 rag_model_name: str,
                 tool_files_dict: Optional[Dict] = None,
                 enable_finish: bool = True,
                 enable_rag: bool = False,
                 force_finish: bool = True,
                 enable_checker: bool = True,
                 step_rag_num: int = 4,
                 seed: Optional[int] = None):
        # Initialization parameters
        self.model_name = model_name
        self.rag_model_name = rag_model_name
        self.tool_files_dict = tool_files_dict or {}
        self.enable_finish = enable_finish
        self.enable_rag = enable_rag
        self.force_finish = force_finish
        self.enable_checker = enable_checker
        self.step_rag_num = step_rag_num
        self.seed = seed

        # Device setup
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Models
        self.model = None
        self.tokenizer = None
        self.rag_model = None

        # Prompts
        self.chat_prompt = "You are a helpful assistant for user chat."

        logger.info(f"Initialized TxAgent with model: {model_name}")
    def init_model(self):
        """Initialize all models and components"""
        try:
            self.load_llm_model()
            if self.enable_rag:
                self.load_rag_model()
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.error(f"Model initialization failed: {str(e)}")
            raise
    def load_llm_model(self):
        """Load the main LLM model and its tokenizer"""
        try:
            logger.info(f"Loading LLM model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
            # device_map="auto" requires the `accelerate` package and lets
            # transformers place the weights across the available devices.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
                trust_remote_code=True
            )
            logger.info(f"LLM model loaded on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load LLM model: {str(e)}")
            raise
    def load_rag_model(self):
        """Load the RAG model"""
        try:
            logger.info(f"Loading RAG model: {self.rag_model_name}")
            self.rag_model = SentenceTransformer(
                self.rag_model_name,
                device=str(self.device)
            )
            logger.info("RAG model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load RAG model: {str(e)}")
            raise
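    def retrieve(self, query: str, documents: List[str]) -> List[str]:
        """Hypothetical retrieval helper (not part of the original class).

        A minimal sketch, assuming the SentenceTransformer loaded above is meant
        to rank candidate tool descriptions or documents: returns the top
        `step_rag_num` matches by cosine similarity.
        """
        if self.rag_model is None:
            raise RuntimeError("RAG model not loaded; call init_model() with enable_rag=True")
        from sentence_transformers import util  # local import keeps the sketch self-contained
        query_emb = self.rag_model.encode(query, convert_to_tensor=True)
        doc_embs = self.rag_model.encode(documents, convert_to_tensor=True)
        scores = util.cos_sim(query_emb, doc_embs)[0]
        top_k = min(self.step_rag_num, len(documents))
        top_indices = torch.topk(scores, k=top_k).indices.tolist()
        return [documents[i] for i in top_indices]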
    def chat(self, message: str, history: Optional[List[Dict]] = None,
             temperature: float = 0.7, max_new_tokens: int = 512) -> str:
        """Handle chat conversations"""
        try:
            conversation = []
            # Initialize with system prompt
            conversation.append({"role": "system", "content": self.chat_prompt})
            # Add history if provided (expected as a list of {"role", "content"} dicts)
            if history:
                for msg in history:
                    conversation.append({"role": msg["role"], "content": msg["content"]})
            # Add current message
            conversation.append({"role": "user", "content": message})

            # Generate response
            inputs = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.device)

            generation_config = GenerationConfig(
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

            outputs = self.model.generate(
                inputs,
                generation_config=generation_config
            )

            # Decode only the newly generated tokens, skipping the prompt
            response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            logger.error(f"Chat failed: {str(e)}")
            raise RuntimeError(f"Chat failed: {str(e)}") from e
    def cleanup(self):
        """Clean up resources"""
        try:
            # Drop model references (set to None rather than del) so repeated
            # calls are safe and the memory can be reclaimed.
            if hasattr(self, 'model'):
                self.model = None
            if hasattr(self, 'rag_model'):
                self.rag_model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Resources cleaned up")
        except Exception as e:
            logger.error(f"Cleanup failed: {str(e)}")
            raise
    def __del__(self):
        """Destructor to ensure resources are released"""
        try:
            self.cleanup()
        except Exception:
            # Never propagate exceptions out of a destructor; the interpreter
            # may already be shutting down when this runs.
            pass
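
# Minimal usage sketch (assumption): the model names below are placeholders,
# not taken from this file; substitute the checkpoints your Space actually uses.
if __name__ == "__main__":
    agent = TxAgent(
        model_name="Qwen/Qwen2.5-0.5B-Instruct",                   # placeholder chat model
        rag_model_name="sentence-transformers/all-MiniLM-L6-v2",   # placeholder embedder
        enable_rag=False,
    )
    agent.init_model()
    print(agent.chat("Summarize common warfarin drug interactions."))
    agent.cleanup()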