mew77 committed
Commit 9a8affc · verified · 1 Parent(s): 06c7b9a

Update hf_model.py

Files changed (1):
  hf_model.py +67 -31
hf_model.py CHANGED
@@ -1,58 +1,94 @@
-
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from datetime import datetime
 import os
+from MewUtilities.mew_log import log_info, log_error  # Import your custom logging methods

 class HFModel:
     def __init__(self, model_name):
         parts = model_name.split("/")
         self.friendly_name = parts[1]
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         self.chat_history = []
         self.log_file = f"chat_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"

+        try:
+            log_info(f"=== Loading Model: {self.friendly_name} ===")
+            self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+            log_info(f"=== Model Loaded Successfully: {self.friendly_name} ===")
+        except Exception as e:
+            log_error(f"ERROR Loading Model: {e}")
+            raise
+
     def generate_response(self, input_text, max_length=100, skip_special_tokens=True):
-        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
-        outputs = self.model.generate(**inputs, max_length=max_length)
-        response = self.tokenizer.decode(outputs[0], skip_special_tokens=skip_special_tokens).strip()
-        return response
+        try:
+            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
+            outputs = self.model.generate(**inputs, max_length=max_length)
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=skip_special_tokens).strip()
+            log_info(f"Generated Response: {response}")
+            return response
+        except Exception as e:
+            log_error(f"ERROR Generating Response: {e}")
+            raise

     def stream_response(self, input_text, max_length=100, skip_special_tokens=True):
-        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
-        for output in self.model.generate(**inputs, max_length=max_length, do_stream=True):
-            response = self.tokenizer.decode(output, skip_special_tokens=skip_special_tokens).strip()
-            yield response
+        try:
+            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
+            for output in self.model.generate(**inputs, max_length=max_length, do_stream=True):
+                response = self.tokenizer.decode(output, skip_special_tokens=skip_special_tokens).strip()
+                yield response
+        except Exception as e:
+            log_error(f"ERROR Streaming Response: {e}")
+            raise

     def chat(self, user_input, max_length=100, skip_special_tokens=True):
-        # Add user input to chat history
-        self.chat_history.append({"role": "user", "content": user_input})
+        try:
+            # Add user input to chat history
+            self.chat_history.append({"role": "user", "content": user_input})
+            log_info(f"User Input: {user_input}")

-        # Generate model response
-        model_response = self.generate_response(user_input, max_length=max_length, skip_special_tokens=skip_special_tokens)
+            # Generate model response
+            model_response = self.generate_response(user_input, max_length=max_length, skip_special_tokens=skip_special_tokens)

-        # Add model response to chat history
-        self.chat_history.append({"role": "assistant", "content": model_response})
+            # Add model response to chat history
+            self.chat_history.append({"role": "assistant", "content": model_response})
+            log_info(f"Assistant Response: {model_response}")

-        # Save chat log
-        self.save_chat_log()
+            # Save chat log
+            self.save_chat_log()

-        return model_response
+            return model_response
+        except Exception as e:
+            log_error(f"ERROR in Chat: {e}")
+            raise

     def save_chat_log(self):
-        with open(self.log_file, "a", encoding="utf-8") as f:
-            for entry in self.chat_history[-2:]:  # Save only the latest interaction
-                role = entry["role"]
-                content = entry["content"]
-                f.write(f"**{role.capitalize()}:**\n\n{content}\n\n---\n\n")
+        try:
+            with open(self.log_file, "a", encoding="utf-8") as f:
+                for entry in self.chat_history[-2:]:  # Save only the latest interaction
+                    role = entry["role"]
+                    content = entry["content"]
+                    f.write(f"**{role.capitalize()}:**\n\n{content}\n\n---\n\n")
+            log_info(f"Chat log saved to {self.log_file}")
+        except Exception as e:
+            log_error(f"ERROR Saving Chat Log: {e}")
+            raise

     def clear_chat_history(self):
-        self.chat_history = []
-        print("Chat history cleared.")
+        try:
+            self.chat_history = []
+            log_info("Chat history cleared.")
+        except Exception as e:
+            log_error(f"ERROR Clearing Chat History: {e}")
+            raise

     def print_chat_history(self):
-        for entry in self.chat_history:
-            role = entry["role"]
-            content = entry["content"]
-            print(f"{role.capitalize()}: {content}\n")
+        try:
+            for entry in self.chat_history:
+                role = entry["role"]
+                content = entry["content"]
+                print(f"{role.capitalize()}: {content}\n")
+            log_info("Printed chat history to console.")
+        except Exception as e:
+            log_error(f"ERROR Printing Chat History: {e}")
+            raise
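
A few notes on the code in this diff. `do_stream=True` is not a parameter of the stock `transformers` `generate()`; it presumably comes from this model's custom remote code (the class loads with `trust_remote_code=True`). With a standard Hub model, `stream_response` as written would fail, and streaming is usually done with `TextIteratorStreamer`, which does ship with `transformers`. A minimal sketch of that alternative (this method body is illustrative, not part of the commit):

```python
# Sketch: streaming with the stock transformers API instead of a custom
# do_stream kwarg. generate() runs in a background thread while the main
# thread yields decoded text chunks as the streamer produces them.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_response(self, input_text, max_new_tokens=100, skip_special_tokens=True):
    inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
    streamer = TextIteratorStreamer(
        self.tokenizer, skip_prompt=True, skip_special_tokens=skip_special_tokens
    )
    thread = Thread(
        target=self.model.generate,
        kwargs=dict(**inputs, max_new_tokens=max_new_tokens, streamer=streamer),
    )
    thread.start()
    yield from streamer  # yields text incrementally as tokens are generated
    thread.join()
```

Two smaller caveats: `max_length=100` caps prompt plus completion, so a long prompt can leave no room for a reply (`max_new_tokens` is usually what is intended), and `chat()` generates from `user_input` alone, so the model never sees `self.chat_history`. If the tokenizer defines a chat template (true of most chat-tuned models), the history can be folded in roughly like this (again a sketch, not the committed code):

```python
# Sketch: condition generation on the accumulated chat history via the
# tokenizer's chat template, assuming the model ships one.
def chat(self, user_input, max_new_tokens=100):
    self.chat_history.append({"role": "user", "content": user_input})
    prompt_ids = self.tokenizer.apply_chat_template(
        self.chat_history, add_generation_prompt=True, return_tensors="pt"
    ).to(self.model.device)
    outputs = self.model.generate(prompt_ids, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens, not the echoed prompt.
    model_response = self.tokenizer.decode(
        outputs[0][prompt_ids.shape[-1]:], skip_special_tokens=True
    ).strip()
    self.chat_history.append({"role": "assistant", "content": model_response})
    self.save_chat_log()
    return model_response
```

Usage stays the same as before the commit; the model id below is a placeholder, any `org/name` Hub id works (note that `__init__` assumes the name contains a "/"):

```python
model = HFModel("some-org/some-model")  # hypothetical Hub id
print(model.chat("Hello!"))
```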