import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Global variables for model, tokenizer, and device
model = None
tokenizer = None
device = None


def init():
    """
    The init function is called once at startup to load the model into memory.
    """
    global model, tokenizer, device

    # Set your model name or path here
    model_name_or_path = "0xroyce/NazareAI-Senior-Marketing-Strategist"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # Put model in evaluation mode


def inference(model_inputs: dict) -> dict:
    """
    This function is called for every request and should return the model's output.
    The input is a dictionary and the output should be a dictionary.
    """
    global model, tokenizer, device

    # Extract the prompt from the input
    prompt = model_inputs.get("prompt", "")
    if not prompt:
        return {"error": "No prompt provided."}

    # Tokenize the prompt and move the tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate text
    # You can adjust parameters like max_new_tokens, temperature, or top_p as needed
    output_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        top_p=0.9,
        temperature=0.7
    )
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"generated_text": output_text}