# nli-implementation/handler.py
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Any, Dict, List, Union
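
# Expected request payload (illustrative example; the "inputs" nesting is
# optional and handled in __call__ below):
#
#   {
#     "inputs": {
#       "premise": "A man is playing a guitar on stage.",
#       "hypothesis": "A person is performing music."
#     }
#   }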

class EndpointHandler:
    def __init__(self, model_id: str):
        """
        Initializes the handler by loading the model and tokenizer.

        Args:
            model_id (str): The Hugging Face model ID (e.g., "MoritzLaurer/DeBERTa-v3-base-mnli").
                This is passed automatically by the Inference Endpoints infrastructure.
        """
        print(f"Loading model '{model_id}'...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)

        # Move the model to the selected device and switch to evaluation
        # mode for consistent inference.
        self.model.to(self.device)
        self.model.eval()
        print("Model and tokenizer loaded successfully.")

        # --- Determine label order ---
        # Preferred: read the labels dynamically from the model config.
        try:
            # Sort by ID to ensure a consistent order even if the dict is unordered.
            sorted_labels = sorted(self.model.config.id2label.items())
            self.label_names = [label for _, label in sorted_labels]
            print(f"Using label names from model config: {self.label_names}")

            # Basic validation for the NLI task.
            if len(self.label_names) != 3:
                print(f"Warning: Expected 3 labels for NLI, but model config has "
                      f"{len(self.label_names)}. Proceeding with the model's labels.")
            if not any("entail" in l.lower() for l in self.label_names) or \
               not any("neutral" in l.lower() for l in self.label_names) or \
               not any("contra" in l.lower() for l in self.label_names):
                print(f"Warning: Model labels {self.label_names} might not match the standard "
                      f"NLI labels ('entailment', 'neutral', 'contradiction').")
        except AttributeError:
            # Fallback: use the conventional NLI labels if the config is missing or malformed.
            self.label_names = ["entailment", "neutral", "contradiction"]
            print(f"Warning: Could not read labels from model config. Falling back to default: {self.label_names}")
            print("Ensure this order matches the actual output order of the model!")

        print(f"Configured label order for output: {self.label_names}")

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Handles inference requests.

        Args:
            data (Dict[str, Any]): The input payload from the request.
                Expected keys: "premise" (str) and "hypothesis" (str),
                optionally nested under "inputs".

        Returns:
            Union[Dict[str, Any], List[Dict[str, Any]]]: A dictionary containing
                error information, or a list of dictionaries, each mapping a
                label name to its probability score.
        """
        # --- Input parsing ---
        inputs = data.get("inputs", data)  # Allow optional "inputs" nesting.
        if not isinstance(inputs, dict):
            return {"error": "Invalid payload. Expected a JSON object with 'premise' and 'hypothesis' keys."}
        premise = inputs.get("premise")
        hypothesis = inputs.get("hypothesis")

        # Basic input validation.
        if not premise or not isinstance(premise, str):
            return {"error": "Missing or invalid 'premise' key in input. Expected a string."}
        if not hypothesis or not isinstance(hypothesis, str):
            return {"error": "Missing or invalid 'hypothesis' key in input. Expected a string."}

        # --- Tokenization ---
        # Tokenize the premise-hypothesis pair as a single sequence pair.
        try:
            tokenized_inputs = self.tokenizer(
                premise,
                hypothesis,
                return_tensors="pt",   # Return PyTorch tensors.
                truncation=True,       # Truncate inputs longer than max_length.
                padding=True,          # Pad to the longest sequence in the batch.
                max_length=self.tokenizer.model_max_length,  # Use the model's max length.
            )
        except Exception as e:
            print(f"Error during tokenization: {e}")
            return {"error": f"Failed to tokenize input: {e}"}

        # Move the tokenized inputs to the same device as the model.
        tokenized_inputs = {k: v.to(self.device) for k, v in tokenized_inputs.items()}

        # --- Inference ---
        try:
            with torch.no_grad():  # Disable gradient tracking for efficiency.
                outputs = self.model(**tokenized_inputs)
                logits = outputs.logits
                # Apply softmax to convert logits into probabilities.
                probabilities = torch.softmax(logits, dim=-1)
                # Move to CPU and take the first (and only) row, since the
                # endpoint processes a single premise/hypothesis pair.
                scores = probabilities.cpu().numpy()[0].tolist()

            # --- Format output ---
            # Pair each label with its corresponding score.
            result = [{"label": label, "score": score} for label, score in zip(self.label_names, scores)]
            return result
        except Exception as e:
            print(f"Error during model inference: {e}")
            # Consider logging the full traceback in a real deployment:
            #   import traceback; traceback.print_exc()
            return {"error": f"Model inference failed: {e}"}