"""Thin convenience wrapper around a Hugging Face causal language model."""

import os
from datetime import datetime
from threading import Thread
from typing import Iterator, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


class HFModel:
    """Load a causal LM plus tokenizer and expose simple generate/chat helpers.

    Attributes:
        friendly_name: Short display name derived from ``model_name``.
        model: The loaded ``AutoModelForCausalLM`` instance (bfloat16).
        tokenizer: The tokenizer matching ``model``.
        chat_history: List of ``{"role": ..., "content": ...}`` dicts, in
            the order the turns happened.
        log_file: Markdown file that ``save_chat_log`` appends to; the name
            is timestamped once, when the instance is created.
    """

    def __init__(self, model_name: str, device: Optional[str] = None):
        """Load the model and tokenizer for ``model_name``.

        Args:
            model_name: Model identifier or path, e.g. ``"org/name"``.
            device: Device to place the model on. When ``None`` (default),
                CUDA is used if available, otherwise the CPU.
        """
        # For "org/name" the display name is "name". Identifiers without a
        # namespace (e.g. "gpt2") have no second component, so fall back to
        # the whole name instead of raising IndexError.
        parts = model_name.split("/")
        self.friendly_name = parts[1] if len(parts) > 1 else parts[0]

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # SECURITY NOTE: trust_remote_code=True lets the model repository run
        # arbitrary Python on this machine. Only load models you trust.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )

        self.chat_history = []
        self.log_file = f"chat_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"

    def generate_response(
        self,
        input_text: str,
        max_length: int = 100,
        skip_special_tokens: bool = True,
    ) -> str:
        """Generate a completion for ``input_text`` and return it decoded.

        Args:
            input_text: Prompt passed to the tokenizer as-is.
            max_length: Forwarded to ``generate`` as ``max_length``; in
                Hugging Face this bounds prompt + generated tokens together.
            skip_special_tokens: Forwarded to ``tokenizer.decode``.

        Returns:
            The decoded first output sequence, stripped of surrounding
            whitespace. The full sequence is decoded, so for decoder-only
            models the text is expected to start with the prompt.
        """
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_length=max_length)
        response = self.tokenizer.decode(
            outputs[0], skip_special_tokens=skip_special_tokens
        ).strip()
        return response

    def stream_response(
        self,
        input_text: str,
        max_length: int = 100,
        skip_special_tokens: bool = True,
    ) -> Iterator[str]:
        """Yield decoded text chunks while the model is still generating.

        ``generate`` has no ``do_stream`` option; streaming is done by
        running ``generate`` in a worker thread with a
        ``TextIteratorStreamer`` and consuming the streamer here.

        Args:
            input_text: Prompt passed to the tokenizer as-is.
            max_length: Forwarded to ``generate`` as ``max_length``.
            skip_special_tokens: Forwarded to the streamer's decode step.

        Yields:
            Non-empty text chunks in generation order. Chunks are not
            stripped, because whitespace at chunk edges separates words;
            concatenating all chunks gives the full text. The prompt is
            included, matching ``generate_response``.

        Raises:
            Exception: Whatever ``generate`` raised in the worker thread,
                re-raised after the stream ends.
        """
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_special_tokens=skip_special_tokens
        )
        failures = []

        def _run_generation():
            try:
                self.model.generate(**inputs, max_length=max_length, streamer=streamer)
            except Exception as exc:  # handed over to the consuming thread below
                failures.append(exc)
                # Without an end-of-stream signal the consumer would block
                # forever waiting on the streamer's queue.
                streamer.end()

        worker = Thread(target=_run_generation, daemon=True)
        worker.start()
        try:
            for chunk in streamer:
                if chunk:
                    yield chunk
        finally:
            # Also runs when the caller closes the generator early; waits for
            # generation to finish so the model is not left in use.
            worker.join()

        if failures:
            raise failures[0]

    def chat(
        self,
        user_input: str,
        max_length: int = 100,
        skip_special_tokens: bool = True,
    ) -> str:
        """Run one chat turn: record the input, generate, record, and log.

        Note that only ``user_input`` is sent to the model; earlier turns in
        ``chat_history`` are stored and logged but not used as context.

        Args:
            user_input: The user's message for this turn.
            max_length: Forwarded to ``generate_response``.
            skip_special_tokens: Forwarded to ``generate_response``.

        Returns:
            The model's response text.
        """
        self.chat_history.append({"role": "user", "content": user_input})

        model_response = self.generate_response(
            user_input,
            max_length=max_length,
            skip_special_tokens=skip_special_tokens,
        )
        self.chat_history.append({"role": "assistant", "content": model_response})

        self.save_chat_log()
        return model_response

    def save_chat_log(self) -> None:
        """Append the two most recent history entries to ``log_file``.

        Called from ``chat`` right after a user/assistant pair was added, so
        each call writes exactly that pair in Markdown form.
        """
        with open(self.log_file, "a", encoding="utf-8") as f:
            # Only the latest interaction; earlier turns were already written.
            for entry in self.chat_history[-2:]:
                role = entry["role"]
                content = entry["content"]
                f.write(f"**{role.capitalize()}:**\n\n{content}\n\n---\n\n")

    def clear_chat_history(self) -> None:
        """Reset the in-memory history; the log file is left untouched."""
        self.chat_history = []
        print("Chat history cleared.")

    def print_chat_history(self) -> None:
        """Print every history entry as ``Role: content`` to stdout."""
        for entry in self.chat_history:
            role = entry["role"]
            content = entry["content"]
            print(f"{role.capitalize()}: {content}\n")