# Charm_15 / inference.py
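"""Inference script for Charm 15, a fine-tuned Mixtral model.

Loads the fine-tuned weights and custom tokenizer, then runs an interactive
chat loop on stdin.
"""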
import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast
# Paths to your fine-tuned model and tokenizer (update these)
MODEL_DIR = "./mixtral_finetuned" # Directory from your training script
TOKENIZER_JSON = "./mixtral_finetuned/tokenizer.json" # Custom tokenizer file
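# Note: this assumes the tokenizer was exported as a raw tokenizer.json. If it was saved
# with save_pretrained() alongside the model, AutoTokenizer.from_pretrained(MODEL_DIR)
# would restore the special-token settings automatically.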
# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
class Charm15Inference:
def __init__(self, model_dir=MODEL_DIR, tokenizer_json=TOKENIZER_JSON):
"""Initialize model and tokenizer for inference."""
try:
# Load tokenizer from JSON (assumes your custom BPE or fine-tuned output)
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_json)
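            # A bare tokenizer.json does not tell transformers which tokens are eos/pad,
            # so both may be None here unless they are passed in explicitly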
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Load model with optimizations
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                torch_dtype=torch.bfloat16,  # Match your training dtype
                device_map="auto",           # Auto-distribute layers across GPU/CPU
                low_cpu_mem_usage=True       # Reduce peak RAM while loading
            )
            # device_map="auto" already places the weights, so no .to(device) here;
            # moving an accelerate-dispatched model can raise an error
print(f"Loaded model from {model_dir} and tokenizer from {tokenizer_json}")
except Exception as e:
print(f"Error loading model/tokenizer: {e}")
raise
def generate_response(self, prompt, max_length=2048, temperature=0.7, top_k=50, top_p=0.95):
"""Generate a response from the model."""
try:
            # Tokenize the prompt and send it to the device the model was dispatched to
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # Generate with sampling; values follow your earlier generation config
            output = self.model.generate(
                **inputs,
                max_length=max_length,       # Total budget: prompt + generated tokens
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                do_sample=True,              # Sampling for variety
                repetition_penalty=1.1,      # From your generation config
                no_repeat_ngram_size=2,      # Discourage repeated n-grams
                use_cache=True               # Reuse the KV cache to speed up decoding
            )
            # Decode only the newly generated tokens so the reply doesn't echo the prompt
            return self.tokenizer.decode(
                output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
            )
except Exception as e:
print(f"Generation error: {e}")
return "Sorry, I couldn’t generate a response."
if __name__ == "__main__":
# Initialize inference class
try:
infer = Charm15Inference()
except Exception as e:
print(f"Initialization failed: {e}")
        raise SystemExit(1)
# Interactive loop
print("Chat with Charm 15 (type 'exit' or 'quit' to stop):")
while True:
user_input = input("User: ")
if user_input.lower() in ["exit", "quit"]:
print("Goodbye!")
break
if not user_input.strip():
print("Charm 15: Please say something!")
continue
response = infer.generate_response(user_input)
print("Charm 15:", response)
# Cleanup
del infer.model
torch.cuda.empty_cache()
print("Memory cleared.")