import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Any, Dict, List, Union

class EndpointHandler:
    def __init__(self, model_id: str):
        """
        Initializes the handler by loading the model and tokenizer.

        Args:
            model_id (str): The Hugging Face model ID or local path to the model repository
                            (e.g., "MoritzLaurer/DeBERTa-v3-base-mnli"). The Inference Endpoints
                            infrastructure passes this automatically when the handler is created.
        """
        print(f"Loading model '{model_id}'...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)

        # Move model to the determined device
        self.model.to(self.device)
        # Set model to evaluation mode for consistent inference
        self.model.eval()
        print("Model and tokenizer loaded successfully.")

        # --- Determine Label Order ---
        # Preferred: Dynamically get labels from model config
        try:
            # Sort by ID to ensure consistent order if dict isn't ordered
            sorted_labels = sorted(self.model.config.id2label.items())
            self.label_names = [label for _, label in sorted_labels]
            print(f"Using label names from model config: {self.label_names}")
            # Basic validation for NLI task
            if len(self.label_names) != 3:
                print(f"Warning: Expected 3 labels for NLI, but model config has {len(self.label_names)}. Proceeding with model's labels.")
            if not any("entail" in l.lower() for l in self.label_names) or \
               not any("neutral" in l.lower() for l in self.label_names) or \
               not any("contra" in l.lower() for l in self.label_names):
                print(f"Warning: Model labels {self.label_names} might not match standard NLI labels ('entailment', 'neutral', 'contradiction').")

        except AttributeError:
            # Fallback: Use the explicitly requested labels if config is missing/malformed
            self.label_names = ["entailment", "neutral", "contradiction"]
            print(f"Warning: Could not read labels from model config. Falling back to default: {self.label_names}")
            print("Ensure this order matches the actual output order of the model!")

        print(f"Configured label order for output: {self.label_names}")


    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Handles inference requests.

        Args:
            data (Dict[str, Any]): The input data payload from the request.
                                   Expected keys: "premise" (str) and "hypothesis" (str).
                                   Can optionally be nested under "inputs".

        Returns:
            Union[Dict[str, Any], List[Dict[str, Any]]]: A dictionary containing error info,
                                                         or a list of dictionaries, each holding
                                                         a label name and its probability score.
        """
        # --- Input Parsing ---
        inputs = data.get("inputs", data)  # Allow for optional "inputs" nesting
        if not isinstance(inputs, dict):
            return {"error": "Invalid payload. Expected a JSON object with 'premise' and 'hypothesis' keys."}
        premise = inputs.get("premise")
        hypothesis = inputs.get("hypothesis")

        # Basic input validation
        if not premise or not isinstance(premise, str):
            return {"error": "Missing or invalid 'premise' key in input. Expected a string."}
        if not hypothesis or not isinstance(hypothesis, str):
            return {"error": "Missing or invalid 'hypothesis' key in input. Expected a string."}

        # --- Tokenization ---
        # Tokenize the premise-hypothesis pair
        try:
            tokenized_inputs = self.tokenizer(
                premise,
                hypothesis,
                return_tensors="pt", # Return PyTorch tensors
                truncation=True,     # Truncate if longer than max length
                padding=True,        # Pad to the longest sequence in the batch (or max_length)
                max_length=self.tokenizer.model_max_length # Use model's max length
            )
        except Exception as e:
             print(f"Error during tokenization: {e}")
             return {"error": f"Failed to tokenize input: {e}"}


        # Move tokenized inputs to the same device as the model
        tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}

        # --- Inference ---
        try:
            with torch.no_grad(): # Disable gradient calculations for efficiency
                outputs = self.model(**tokenized_inputs)
                logits = outputs.logits

            # Apply Softmax to get probabilities
            probabilities = torch.softmax(logits, dim=-1)

            # Move probabilities to CPU and convert to list
            # Squeeze or index [0] if processing single pairs (typical for endpoints)
            scores = probabilities.cpu().numpy()[0].tolist()

            # --- Format Output ---
            # Pair labels with their corresponding scores
            result = [{"label": label, "score": score} for label, score in zip(self.label_names, scores)]

            return result

        except Exception as e:
            print(f"Error during model inference: {e}")
            # Consider logging the full traceback here in a real deployment
            # import traceback
            # traceback.print_exc()
            return {"error": f"Model inference failed: {e}"}