# mr687
# initial app
# 1eecf37
import os
import secrets

import torch
import torch.nn.functional as F
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification
from typing_extensions import Annotated

from Sentiment import app
from .model import model, tokenizer
# HTTP Basic scheme; used below as the credential dependency of /predict.
app_security = HTTPBasic()

# Shared secret that callers must send as the Basic-auth password.
# Fail fast at import time when it is unset or empty, so the service never
# starts without a key to check against.
api_key = os.environ.get("API_KEY")
if api_key is None or api_key == "":
    raise ValueError("API_KEY is missing.")

# Maps the index of the highest-scoring logit to its sentiment label.
i2l = dict(enumerate(("positive", "neutral", "negative")))
class AnalyzeRequest(BaseModel):
    # Request body accepted by the /predict endpoint.
    # text: the raw input string whose sentiment is to be classified.
    text: str
class AnalyzeResponse(BaseModel):
    # Response schema declared as the response_model of the /predict endpoint.
    # text: the input text, echoed back unchanged.
    text: str
    # label: the predicted sentiment label (one of the values of i2l).
    label: str
    # score: softmax probability of the predicted label.
    score: float
@app.get("/checkhealth", tags=["CheckHealth"])
def checkhealth():
    # Unauthenticated liveness probe: returns a fixed message and touches
    # neither the model nor the API key.
    status_message = "Sentiment API is running."
    return status_message
@app.post("/predict", tags=["Analyze"], summary="Analyze text from prompt", response_model=AnalyzeResponse)
def predict(creds: Annotated[HTTPBasicCredentials, Depends(app_security)], data: AnalyzeRequest):
    """Classify the sentiment of ``data.text`` for an authenticated caller.

    Only the Basic-auth password is checked (against the ``API_KEY`` secret);
    the username is ignored.

    Args:
        creds: HTTP Basic credentials extracted from the request.
        data: Request body carrying the text to classify.

    Returns:
        A dict with the original ``text``, the predicted ``label`` and, as
        ``score``, the softmax probability of that label rounded to three
        decimals.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            supplied password does not match the API key.
    """
    # Constant-time comparison so response timing does not reveal how much of
    # the key was guessed. Compare bytes: compare_digest rejects non-ASCII str.
    password_ok = secrets.compare_digest(
        creds.password.encode("utf-8"),
        api_key.encode("utf-8"),
    )
    if not password_ok:
        # Deliberately do not log the supplied password or the real key.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect Password",
            headers={"WWW-Authenticate": "Basic"},
        )

    text = data.text
    # NOTE(review): the input is not truncated here; text that tokenizes to
    # more positions than the model supports will presumably fail inside the
    # forward pass — confirm the limit and cap the length if so.
    token_ids = tokenizer.encode(text)
    # Shape the ids into a single-row batch on the model's device.
    batch = torch.LongTensor(token_ids).view(1, -1).to(model.device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(batch)[0]
        probs = F.softmax(logits, dim=-1).squeeze()

    label_index = logits.argmax(dim=-1).item()
    # Return a real float (the response model declares ``score: float``)
    # instead of a pre-formatted string.
    score = round(probs[label_index].item(), 3)
    return {"text": text, "label": i2l[label_index], "score": score}