from typing import Dict, List, Any
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import pipeline, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the ONNX-optimized model and its tokenizer from the repository path.
        model = ORTModelForSequenceClassification.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)
        # Create the text-classification inference pipeline.
        self.pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data: A dictionary containing the payload for inference.
                - messages: List of dictionaries containing conversation messages.
                    - role: String indicating the role ("system" or "user").
                    - content: String containing the message text.
                - stop: List of stop sequences (optional).
                - frequency_penalty: Penalty on tokens proportional to how often
                  they have already appeared (optional).
                - presence_penalty: Penalty on tokens that have already appeared
                  at least once (optional).
                - min_p: Minimum probability threshold for token sampling (optional).
              Note: the four generation-style parameters above are accepted for
              payload compatibility but are not used, since the text-classification
              pipeline does not support them.

        Returns:
            A list with one prediction per message. Each prediction is a dictionary with:
                - label: A string representing the predicted class.
                - score: A float between 0 and 1 indicating the model's confidence.
        """
        # Read the generation-style options for payload compatibility; they are
        # deliberately not forwarded, because the text-classification pipeline
        # rejects unknown keyword arguments and the call would raise an error.
        _ignored = {
            "stop_words": data.get("stop", []),
            "frequency_penalty": data.get("frequency_penalty", 0),
            "presence_penalty": data.get("presence_penalty", 0),
            "min_p": data.get("min_p", 1),
        }
        # Extract the raw message texts for inference; the pipeline expects plain
        # strings (or {"text": ..., "text_pair": ...} dicts), not role-tagged dicts.
        messages = data["messages"]
        inputs = [message["content"] for message in messages]

        # Run classification; the result is one {"label", "score"} dict per input.
        prediction = self.pipeline(inputs)
        return prediction
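

# A minimal local smoke test (an illustrative sketch: the sample payload below
# and the local path "." are assumptions, not part of the endpoint contract; in
# a real deployment the Inference Endpoints runtime instantiates EndpointHandler
# and calls it with the request payload).
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    sample_payload = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "I love this product!"},
        ],
        "min_p": 0.05,  # accepted but ignored by the classification pipeline
    }
    # Prints one {"label", "score"} dict per message, e.g.
    # [{'label': '...', 'score': 0.97}, ...] depending on the model's classes.
    print(handler(sample_payload))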