File size: 5,422 Bytes
80b0f9f
fd39839
cf95a11
fd39839
 
 
61c414a
adac5ab
3cfe99a
 
adac5ab
3cfe99a
fd39839
61c414a
fd39839
 
 
 
 
 
 
 
adac5ab
 
 
fd39839
 
 
 
 
 
 
 
 
adac5ab
 
fd39839
 
 
 
 
 
 
 
 
 
 
 
adac5ab
61c414a
fd39839
 
3cfe99a
 
adac5ab
 
 
3cfe99a
adac5ab
3cfe99a
61c414a
fd39839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adac5ab
fd39839
61c414a
fd39839
 
 
 
adac5ab
 
 
 
fd39839
 
adac5ab
fd39839
61c414a
fd39839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adac5ab
fd39839
61c414a
fd39839
 
3cfe99a
 
 
 
 
 
adac5ab
3cfe99a
adac5ab
3cfe99a
61c414a
fd39839
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import logging
import torch
from typing import Dict, Optional, List, Union
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from sentence_transformers import SentenceTransformer

# Configure process-wide logging for Hugging Face Spaces: INFO and above,
# each record stamped with time, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module logger shared by every TxAgent method.
logger = logging.getLogger("TxAgent")

class TxAgent:
    """Chat agent wrapping a Hugging Face causal LM and an optional RAG encoder.

    Construction only records configuration. Call :meth:`init_model` to load
    the LLM (and, when ``enable_rag`` is true, the SentenceTransformer model)
    before calling :meth:`chat`.
    """

    def __init__(self, 
                 model_name: str,
                 rag_model_name: str,
                 tool_files_dict: Optional[Dict] = None,
                 enable_finish: bool = True,
                 enable_rag: bool = False,
                 force_finish: bool = True,
                 enable_checker: bool = True,
                 step_rag_num: int = 4,
                 seed: Optional[int] = None):
        """Record configuration; no models are loaded here.

        Args:
            model_name: Hugging Face id or path of the causal LM.
            rag_model_name: Id or path of the SentenceTransformer model,
                loaded only when ``enable_rag`` is true.
            tool_files_dict: Optional tool-file mapping; ``None`` becomes ``{}``.
            enable_finish, force_finish, enable_checker, step_rag_num, seed:
                Stored on the instance; not used by the methods of this class.
            enable_rag: Whether :meth:`init_model` also loads the RAG model.
        """
        # Initialization parameters
        self.model_name = model_name
        self.rag_model_name = rag_model_name
        self.tool_files_dict = tool_files_dict or {}
        self.enable_finish = enable_finish
        self.enable_rag = enable_rag
        self.force_finish = force_finish
        self.enable_checker = enable_checker
        self.step_rag_num = step_rag_num
        # NOTE(review): seed is stored but never applied in this class.
        self.seed = seed

        # Preferred device: CUDA when available, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Models are populated lazily by init_model().
        self.model = None
        self.tokenizer = None
        self.rag_model = None

        # System prompt prepended to every chat conversation.
        self.chat_prompt = "You are a helpful assistant for user chat."

        logger.info(f"Initialized TxAgent with model: {model_name}")

    def init_model(self):
        """Load the LLM and, if ``enable_rag`` is set, the RAG model.

        Raises:
            Exception: Re-raises whatever the underlying loader raised,
                after logging it with its traceback.
        """
        try:
            self.load_llm_model()
            if self.enable_rag:
                self.load_rag_model()
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.exception(f"Model initialization failed: {e}")
            raise

    def load_llm_model(self):
        """Load the tokenizer and causal LM named by ``self.model_name``.

        Uses float16 weights on CUDA and float32 on CPU. ``device_map="auto"``
        lets Transformers decide where the weights are placed.
        """
        try:
            logger.info(f"Loading LLM model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
                trust_remote_code=True
            )
            logger.info(f"LLM model loaded on {self.device}")
        except Exception as e:
            logger.exception(f"Failed to load LLM model: {e}")
            raise

    def load_rag_model(self):
        """Load the SentenceTransformer named by ``self.rag_model_name`` onto ``self.device``."""
        try:
            logger.info(f"Loading RAG model: {self.rag_model_name}")
            self.rag_model = SentenceTransformer(
                self.rag_model_name,
                device=str(self.device)
            )
            logger.info("RAG model loaded successfully")
        except Exception as e:
            logger.exception(f"Failed to load RAG model: {e}")
            raise

    def chat(self, message: str, history: Optional[List[Dict]] = None, 
             temperature: float = 0.7, max_new_tokens: int = 512) -> str:
        """Generate one assistant reply for a chat conversation.

        Args:
            message: The new user message.
            history: Optional prior turns, each a mapping with ``"role"`` and
                ``"content"`` keys. They are inserted between the system
                prompt and ``message``.
            temperature: Sampling temperature. Values ``<= 0`` switch to
                greedy decoding, because sampling requires a strictly
                positive temperature.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded reply (prompt tokens excluded), whitespace-stripped.

        Raises:
            RuntimeError: If the model/tokenizer is not loaded, or if prompt
                construction or generation fails (original error chained).
        """
        # Fail fast with an actionable message instead of an opaque
        # AttributeError on None.
        if getattr(self, "model", None) is None or getattr(self, "tokenizer", None) is None:
            raise RuntimeError("Model not initialized; call init_model() first")

        try:
            conversation = [{"role": "system", "content": self.chat_prompt}]
            conversation.extend(
                {"role": msg["role"], "content": msg["content"]}
                for msg in (history or [])
            )
            conversation.append({"role": "user", "content": message})

            # With device_map="auto" the weights may not sit on self.device,
            # so follow the model's own device for the inputs.
            input_ids = self.tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.model.device)
            # Single unpadded sequence: every position is attended to.
            attention_mask = torch.ones_like(input_ids)

            if temperature > 0:
                generation_config = GenerationConfig(
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            else:
                # Sampling rejects a non-positive temperature; decode greedily.
                generation_config = GenerationConfig(
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                generation_config=generation_config
            )

            # Decode only the newly generated tokens, skipping the prompt.
            response = self.tokenizer.decode(
                outputs[0][input_ids.shape[1]:], skip_special_tokens=True
            )
            return response.strip()

        except Exception as e:
            logger.exception(f"Chat failed: {e}")
            raise RuntimeError(f"Chat failed: {e}") from e

    def cleanup(self):
        """Drop model references and release cached CUDA memory.

        References are reset to ``None`` rather than deleted. Calling this
        twice is therefore harmless, and a later :meth:`chat` fails with a
        clear "not initialized" error instead of an ``AttributeError``.
        """
        import gc

        try:
            self.model = None
            self.rag_model = None
            # Collect any reference cycles so the tensors are actually freed
            # before the CUDA caching allocator is asked to release memory.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Resources cleaned up")
        except Exception as e:
            logger.error(f"Cleanup failed: {e}")
            raise

    def __del__(self):
        """Best-effort cleanup when the instance is garbage-collected."""
        # Exceptions cannot propagate out of __del__ (Python only prints
        # "Exception ignored in ..."). During interpreter shutdown, module
        # globals such as torch or logger may already be gone, so failures
        # here are deliberately ignored.
        try:
            self.cleanup()
        except Exception:
            pass