File size: 2,117 Bytes
bbcc78a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import Dict, List, Any
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import pipeline, AutoTokenizer


class EndpointHandler():
  """Hugging Face Inference Endpoint handler for an ONNX-optimized
  text-classification model (loaded via optimum's ONNX Runtime backend)."""

  def __init__(self, path=""):
    # Load the ONNX-optimized model and its tokenizer from the repo path.
    model = ORTModelForSequenceClassification.from_pretrained(path)
    tokenizer = AutoTokenizer.from_pretrained(path)
    # Build the inference pipeline once at startup; reused for every request.
    self.pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)

  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, float]]:
    """Classify each conversation message in the request payload.

    Args:
      data: Payload dictionary.
        - messages: List of dicts, each with:
          - role: String ("system" or "user") — metadata only; not used
            for classification.
          - content: String containing the message text to classify.
        Text-generation keys ("stop", "frequency_penalty",
        "presence_penalty", "min_p") are tolerated in the payload for
        backward compatibility but ignored: a text-classification
        pipeline accepts no such parameters, and forwarding them makes
        the underlying tokenizer raise a TypeError.

    Returns:
      One prediction per message, in input order. Each prediction is a
      dictionary with:
        - label: A string with the predicted class name.
        - score: A float in [0, 1] with the model's confidence.

    Raises:
      KeyError: If "messages" is absent, or a message lacks "content".
    """
    messages = data["messages"]
    # The classifier consumes raw text only — strip the role metadata.
    # (The original {"role": ..., "text": ...} dicts are not a shape the
    # text-classification pipeline understands.)
    texts = [message["content"] for message in messages]
    if not texts:
      # Avoid invoking the pipeline on an empty batch.
      return []
    return self.pipeline(texts)