File size: 1,173 Bytes
867fc5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from typing import Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class EndpointHandler:
    """Text-generation request handler built around a causal language model.

    The tokenizer and model are loaded once, at construction time, and are
    reused for every request passed to ``__call__``.
    """

    # Total-length cap used when the request does not choose its own limit.
    DEFAULT_MAX_LENGTH = 100

    def __init__(self, path=""):
        """Load the tokenizer and model found at ``path``.

        Args:
            path: Model directory or identifier handed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        # Inference only: put the model in evaluation mode.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a continuation for the text in ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` holds the prompt text. The
                optional ``"parameters"`` entry is a dictionary of keyword
                arguments forwarded to ``model.generate`` (for example
                ``max_new_tokens`` or ``temperature``).

        Returns:
            ``{"generated_text": ...}`` on success, or ``{"error": ...}`` when
            the prompt is missing/empty or ``"parameters"`` is not a
            dictionary.
        """
        input_text = data.get("inputs", "")
        if not input_text:
            return {"error": "No input provided"}

        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            return {"error": "'parameters' must be a dictionary"}

        # Copy so the caller's payload is never mutated.
        generation_kwargs = dict(parameters)
        # Keep the historical default only when the caller picked no length
        # limit. ``max_length`` counts prompt tokens too, so callers with long
        # prompts should send ``max_new_tokens`` instead.
        if (
            "max_length" not in generation_kwargs
            and "max_new_tokens" not in generation_kwargs
        ):
            generation_kwargs["max_length"] = self.DEFAULT_MAX_LENGTH

        inputs = self.tokenizer(input_text, return_tensors="pt")

        # No gradients are needed for generation.
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        # Only the first returned sequence is decoded.
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": generated_text}