from threading import Thread
from datetime import datetime

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

from mew_log import log_info, log_error  # Custom logging helpers
class HFModel:
    def __init__(self, model_name):
        # Use the last path component so names without a "user/" namespace also work.
        self.friendly_name = model_name.split("/")[-1]
        self.chat_history = []
        self.log_file = f"chat_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
        try:
            log_info(f"=== Loading Model: {self.friendly_name} ===")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, trust_remote_code=True, torch_dtype=torch.bfloat16
            )  # .cuda()
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            log_info(f"=== Model Loaded Successfully: {self.friendly_name} ===")
        except Exception as e:
            log_error(f"ERROR Loading Model: {e}", e)
            raise
    def generate_response(self, input_text, max_length=100, skip_special_tokens=True):
        try:
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                pad_token_id=self.tokenizer.eos_token_id,  # Ensure proper padding
                do_sample=True,    # Enable sampling for more diverse outputs
                top_k=50,          # Limit sampling to the top-k tokens
                top_p=0.95,        # Use nucleus sampling
                temperature=0.7,   # Control randomness
            )
            # Decode only the newly generated tokens so the prompt is not echoed back.
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=skip_special_tokens).strip()
            log_info(f"Generated Response: {response}")
            return response
        except Exception as e:
            log_error(f"ERROR Generating Response: {e}", e)
            raise
    def stream_response(self, input_text, max_length=100, skip_special_tokens=True):
        try:
            # `generate` has no `do_stream` flag; stream with TextIteratorStreamer instead,
            # running generation in a background thread and yielding text chunks as they arrive.
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=skip_special_tokens)
            kwargs = dict(**inputs, max_length=max_length, pad_token_id=self.tokenizer.eos_token_id, streamer=streamer)
            Thread(target=self.model.generate, kwargs=kwargs).start()
            for text in streamer:
                yield text
        except Exception as e:
            log_error(f"ERROR Streaming Response: {e}", e)
            raise
    def chat(self, user_input, max_length=100, skip_special_tokens=True):
        try:
            # Add user input to the chat history
            self.chat_history.append({"role": "user", "content": user_input})
            log_info(f"User Input: {user_input}")
            # Generate the model response (note: only the latest turn is sent to the model)
            model_response = self.generate_response(user_input, max_length=max_length, skip_special_tokens=skip_special_tokens)
            # Add the model response to the chat history
            self.chat_history.append({"role": "assistant", "content": model_response})
            log_info(f"Assistant Response: {model_response}")
            # Persist the latest exchange to the log file
            self.save_chat_log()
            return model_response
        except Exception as e:
            log_error(f"ERROR in Chat: {e}", e)
            raise
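
    # Sketch (not in the original): a history-aware variant of `chat`. It assumes the
    # loaded tokenizer ships a chat template (true for most instruct-tuned models);
    # the method name `chat_with_history` is hypothetical.
    def chat_with_history(self, user_input, max_length=100, skip_special_tokens=True):
        self.chat_history.append({"role": "user", "content": user_input})
        # Render the full history into the model's expected prompt format.
        prompt = self.tokenizer.apply_chat_template(
            self.chat_history, tokenize=False, add_generation_prompt=True
        )
        model_response = self.generate_response(prompt, max_length=max_length, skip_special_tokens=skip_special_tokens)
        self.chat_history.append({"role": "assistant", "content": model_response})
        self.save_chat_log()
        return model_response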
    def save_chat_log(self):
        try:
            with open(self.log_file, "a", encoding="utf-8") as f:
                for entry in self.chat_history[-2:]:  # Append only the latest exchange
                    role = entry["role"]
                    content = entry["content"]
                    f.write(f"**{role.capitalize()}:**\n\n{content}\n\n---\n\n")
            log_info(f"Chat log saved to {self.log_file}")
        except Exception as e:
            log_error(f"ERROR Saving Chat Log: {e}", e)
            raise
    def clear_chat_history(self):
        try:
            self.chat_history = []
            log_info("Chat history cleared.")
        except Exception as e:
            log_error(f"ERROR Clearing Chat History: {e}", e)
            raise
    def print_chat_history(self):
        try:
            for entry in self.chat_history:
                role = entry["role"]
                content = entry["content"]
                print(f"{role.capitalize()}: {content}\n")
            log_info("Printed chat history to console.")
        except Exception as e:
            log_error(f"ERROR Printing Chat History: {e}", e)
            raise
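
# Example usage sketch (not part of the original file). The model id below is an
# assumption; substitute any causal LM available on the Hugging Face Hub.
if __name__ == "__main__":
    model = HFModel("Qwen/Qwen2.5-0.5B-Instruct")
    print(model.chat("Hello! What can you do?"))
    # Streaming variant: chunks are printed as they are generated.
    for chunk in model.stream_response("Tell me a short story."):
        print(chunk, end="", flush=True)
    model.print_chat_history()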