my-ai-space / app /models /predict.py
prometheus04's picture
AI Space Builder: Space content update for prometheus04/my-ai-space
fd82a77 verified
raw
history blame contribute delete
671 Bytes
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import json
# Hugging Face checkpoint id. Per its name, a DistilBERT (uncased) model
# fine-tuned on SST-2 sentiment data.
model_name = "distilbert-base-uncased-finetuned-sst-2-english"

# Tokenizer and classifier are created once, at import time, and reused by
# predict() on every call.
# NOTE(review): from_pretrained may need network access the first time it
# runs; confirm the deployment has the weights cached.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
def predict(text):
    """Classify a piece of text with the module-level sequence classifier.

    Args:
        text: The input string to classify.

    Returns:
        A dict with:
          - "text": the input, echoed back unchanged;
          - "prediction": integer index of the highest-probability class
            (argmax over the flattened probabilities, which is the class
            index for a single input);
          - "probability": softmax probabilities as a nested Python list
            (one row per tokenized input).
    """
    # Truncate to the tokenizer's maximum length: without it, inputs longer
    # than the model's position-embedding table make the forward pass fail.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)

    # Inference only: disable autograd so no backward graph is built.
    with torch.no_grad():
        outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
        probabilities = torch.softmax(outputs.logits, dim=-1)
        predicted_class = torch.argmax(probabilities)

    return {
        "text": text,
        "prediction": predicted_class.item(),
        "probability": probabilities.tolist(),
    }