Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import torch.nn.functional as F | |
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification | |
from fastapi import FastAPI, Depends, HTTPException, status | |
from fastapi.security import HTTPBasic, HTTPBasicCredentials | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
from typing_extensions import Annotated | |
from Sentiment import app | |
from .model import model, tokenizer | |
# HTTP Basic scheme; used below as the FastAPI dependency that extracts
# the caller's credentials.
app_security = HTTPBasic()

# Shared secret that callers must supply as the Basic-auth password.
# Fail at import time if it is unset or empty.
api_key = os.environ.get("API_KEY")
if api_key is None or api_key == "":
    raise ValueError("API_KEY is missing.")
# Maps the classifier's output index to its human-readable sentiment label.
i2l = dict(enumerate(('positive', 'neutral', 'negative')))
class AnalyzeRequest(BaseModel):
    """Request body for sentiment analysis: the raw text to classify."""

    text: str
class AnalyzeResponse(BaseModel):
    """Shape of a sentiment-analysis result.

    The fields mirror the keys of the dict built by ``predict``.
    """

    # The text that was analyzed.
    text: str
    # Predicted sentiment label (one of the values of ``i2l``).
    label: str
    # Softmax probability assigned to the predicted label.
    score: float
def checkhealth():
    """Return a fixed message indicating that the service is up."""
    status_message = "Sentiment API is running."
    return status_message
def predict(creds: Annotated[HTTPBasicCredentials, Depends(app_security)], data: AnalyzeRequest):
    """Classify the sentiment of ``data.text`` after verifying the API key.

    Args:
        creds: HTTP Basic credentials resolved by ``app_security``. Only the
            password is checked (against the module-level ``api_key``); the
            username is ignored.
        data: Request body holding the text to classify.

    Returns:
        A dict with the input ``text``, the predicted ``label`` (looked up in
        ``i2l``) and ``score``: the softmax probability of that label as a
        float rounded to three decimals, matching ``AnalyzeResponse.score``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            supplied password does not match ``api_key``.
    """
    # Local import keeps this block self-contained; `secrets` is stdlib.
    import secrets

    # Constant-time comparison on bytes: avoids leaking, via response timing,
    # how much of the key matched, and works for non-ASCII passwords.
    # The credentials are deliberately never printed or logged.
    supplied_key = creds.password.encode("utf-8")
    expected_key = api_key.encode("utf-8")
    if not secrets.compare_digest(supplied_key, expected_key):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect Password",
            headers={"WWW-Authenticate": "Basic"},
        )

    text = data.text

    # Truncate to the model's position-embedding limit so overly long input
    # is clipped instead of failing inside the model.
    # NOTE(review): assumes `model` is a Hugging Face BERT-style model whose
    # config exposes `max_position_embeddings` — confirm against `.model`.
    token_ids = tokenizer.encode(
        text,
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    # Shape (1, seq_len): a batch containing the single input sequence.
    batch = torch.LongTensor(token_ids).view(1, -1).to(model.device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)[0]

    label_index = torch.argmax(logits, dim=-1).squeeze().item()
    probabilities = F.softmax(logits, dim=-1).squeeze()
    score = round(probabilities[label_index].item(), 3)

    return {"text": text, "label": i2l[label_index], "score": score}