import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Any, Dict, List


class EndpointHandler:

    def __init__(self, model_id: str):
        """
        Initializes the handler by loading the model and tokenizer.

        Args:
            model_id (str): The Hugging Face model ID (e.g.,
                "MoritzLaurer/DeBERTa-v3-base-mnli"). This is passed in
                automatically by the Inference Endpoints infrastructure.
        """
        print(f"Loading model '{model_id}'...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)
        self.model.to(self.device)
        self.model.eval()  # disable dropout for deterministic inference
        print("Model and tokenizer loaded successfully.")

        try:
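            # For a typical MNLI checkpoint such as the example model above,
            # config.id2label is {0: "entailment", 1: "neutral", 2: "contradiction"};
            # sorting by id keeps label_names aligned with the logits' column order.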
            sorted_labels = sorted(self.model.config.id2label.items())
            self.label_names = [label for _, label in sorted_labels]
            print(f"Using label names from model config: {self.label_names}")

            if len(self.label_names) != 3:
                print(f"Warning: Expected 3 labels for NLI, but model config has "
                      f"{len(self.label_names)}. Proceeding with the model's labels.")
            if not any("entail" in l.lower() for l in self.label_names) or \
               not any("neutral" in l.lower() for l in self.label_names) or \
               not any("contra" in l.lower() for l in self.label_names):
                print(f"Warning: Model labels {self.label_names} may not match the "
                      f"standard NLI labels ('entailment', 'neutral', 'contradiction').")
        except AttributeError:
            # The config carries no usable id2label mapping; fall back to the
            # conventional MNLI label order.
            self.label_names = ["entailment", "neutral", "contradiction"]
            print(f"Warning: Could not read labels from model config. Falling back to default: {self.label_names}")
            print("Ensure this order matches the actual output order of the model!")

        print(f"Configured label order for output: {self.label_names}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any] | List[Dict[str, Any]]:
        """
        Handles inference requests.

        Args:
            data (Dict[str, Any]): The input payload from the request.
                Expected keys: "premise" (str) and "hypothesis" (str),
                optionally nested under "inputs".

        Returns:
            Dict[str, Any] | List[Dict[str, Any]]: A dictionary with error
                information, or a list of dictionaries, each mapping a label
                name to its probability score.
        """

        inputs = data.get("inputs", data)
        premise = inputs.get("premise")
        hypothesis = inputs.get("hypothesis")

        if not premise or not isinstance(premise, str):
            return {"error": "Missing or invalid 'premise' key in input. Expected a string."}
        if not hypothesis or not isinstance(hypothesis, str):
            return {"error": "Missing or invalid 'hypothesis' key in input. Expected a string."}

        try:
            tokenized_inputs = self.tokenizer(
                premise,
                hypothesis,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=self.tokenizer.model_max_length,
            )
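            # The tokenizer typically returns "input_ids" and "attention_mask"
            # tensors of shape (1, seq_len); some models also get "token_type_ids".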
        except Exception as e:
            print(f"Error during tokenization: {e}")
            return {"error": f"Failed to tokenize input: {e}"}

        # Move the input tensors onto the same device as the model.
        tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}

        try:
            with torch.no_grad():
                outputs = self.model(**tokenized_inputs)
                logits = outputs.logits

            # Convert the raw logits into a probability distribution over labels.
            probabilities = torch.softmax(logits, dim=-1)

            # A single premise/hypothesis pair was encoded, so take the first
            # (and only) row of the batch.
            scores = probabilities.cpu().numpy()[0].tolist()

            result = [
                {"label": label, "score": score}
                for label, score in zip(self.label_names, scores)
            ]
            return result

        except Exception as e:
            print(f"Error during model inference: {e}")
            return {"error": f"Model inference failed: {e}"}