from datetime import datetime
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

from .mew_log import log_info, log_error  # Custom logging helpers

class HFModel:
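    """Lightweight wrapper around a Hugging Face causal LM that adds
    one-shot generation, token streaming, and a Markdown chat log."""
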
    def __init__(self, model_name):
        # Use the last path component as a display name; works whether or not
        # the model name includes an org prefix (e.g. "org/model" or "model")
        self.friendly_name = model_name.split("/")[-1]
        self.chat_history = []
        self.log_file = f"chat_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"

        try:
            log_info(f"=== Loading Model: {self.friendly_name} ===")
            # Fall back to CPU instead of crashing on machines without CUDA
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, trust_remote_code=True, torch_dtype=torch.bfloat16
            ).to(device)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            log_info(f"=== Model Loaded Successfully: {self.friendly_name} ===")
        except Exception as e:
            log_error(f"ERROR Loading Model: {e}")
            raise

    def generate_response(self, input_text, max_new_tokens=100, skip_special_tokens=True):
        try:
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            # max_new_tokens bounds the generated tokens only; max_length counts
            # the prompt too and errors out on inputs longer than the limit
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
            # Decode only the newly generated tokens, not the echoed prompt
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=skip_special_tokens).strip()
            log_info(f"Generated Response: {response}")
            return response
        except Exception as e:
            log_error(f"ERROR Generating Response: {e}")
            raise

    def stream_response(self, input_text, max_new_tokens=100, skip_special_tokens=True):
        try:
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            # generate() has no streaming flag; the supported pattern is a
            # TextIteratorStreamer fed by generate() running in a background
            # thread, while this generator yields text chunks as they arrive
            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=skip_special_tokens)
            thread = Thread(target=self.model.generate, kwargs=dict(**inputs, max_new_tokens=max_new_tokens, streamer=streamer))
            thread.start()
            yield from streamer
            thread.join()
        except Exception as e:
            log_error(f"ERROR Streaming Response: {e}")
            raise

    def chat(self, user_input, max_new_tokens=100, skip_special_tokens=True):
        try:
            # Add user input to chat history
            self.chat_history.append({"role": "user", "content": user_input})
            log_info(f"User Input: {user_input}")

            # Build the prompt from the full history (via the tokenizer's chat
            # template when one is defined) so the model sees earlier turns
            if getattr(self.tokenizer, "chat_template", None):
                prompt = self.tokenizer.apply_chat_template(self.chat_history, tokenize=False, add_generation_prompt=True)
            else:
                prompt = user_input

            # Generate model response
            model_response = self.generate_response(prompt, max_new_tokens=max_new_tokens, skip_special_tokens=skip_special_tokens)

            # Add model response to chat history
            self.chat_history.append({"role": "assistant", "content": model_response})
            log_info(f"Assistant Response: {model_response}")

            # Save chat log
            self.save_chat_log()

            return model_response
        except Exception as e:
            log_error(f"ERROR in Chat: {e}")
            raise

    def save_chat_log(self):
        try:
            with open(self.log_file, "a", encoding="utf-8") as f:
                for entry in self.chat_history[-2:]:  # Save only the latest interaction
                    role = entry["role"]
                    content = entry["content"]
                    f.write(f"**{role.capitalize()}:**\n\n{content}\n\n---\n\n")
            log_info(f"Chat log saved to {self.log_file}")
        except Exception as e:
            log_error(f"ERROR Saving Chat Log: {e}")
            raise

    def clear_chat_history(self):
        try:
            self.chat_history = []
            log_info("Chat history cleared.")
        except Exception as e:
            log_error(f"ERROR Clearing Chat History: {e}")
            raise

    def print_chat_history(self):
        try:
            for entry in self.chat_history:
                role = entry["role"]
                content = entry["content"]
                print(f"{role.capitalize()}: {content}\n")
            log_info("Printed chat history to console.")
        except Exception as e:
            log_error(f"ERROR Printing Chat History: {e}")
            raise
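
# A minimal usage sketch, assuming this file lives in a package (run it with
# `python -m <package>.<module>` so the relative import of mew_log resolves).
# The model name below is illustrative, not prescribed by this module.
if __name__ == "__main__":
    model = HFModel("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical model choice
    print(model.chat("Give me one sentence about black holes."))
    # Stream a second response token-by-token to the console
    for chunk in model.stream_response("Count from 1 to 5.", max_new_tokens=30):
        print(chunk, end="", flush=True)
    print()
    model.print_chat_history()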