import os
import secrets

import torch
import torch.nn.functional as F
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel
from typing_extensions import Annotated

from Sentiment import app
from .model import model, tokenizer
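
# The `.model` module is not shown in this file. A minimal sketch of what it is
# assumed to expose (the checkpoint path and device handling below are
# illustrative, not taken from this repository):
#
#   # Sentiment/model.py (hypothetical)
#   import torch
#   from transformers import BertTokenizer, BertForSequenceClassification
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#   model = BertForSequenceClassification.from_pretrained("path/to/finetuned-checkpoint",
#                                                         num_labels=3)
#   model.to(device).eval()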

app_security = HTTPBasic()
api_key = os.getenv("API_KEY")

# Fail fast at startup if the shared secret is not configured.
if not api_key:
  raise ValueError("API_KEY environment variable is not set.")

i2l = {0: 'positive', 1: 'neutral', 2: 'negative'}

class AnalyzeRequest(BaseModel):
  text: str

class AnalyzeResponse(BaseModel):
  text: str
  label: str
  score: float

@app.get("/checkhealth", tags=["CheckHealth"])
def checkhealth():
  return "Sentiment API is running."

@app.post("/predict", tags=["Analyze"], summary="Analyze text from prompt", response_model=AnalyzeResponse)
def predict(creds: Annotated[HTTPBasicCredentials, Depends(app_security)], data: AnalyzeRequest):
  # Only the password is checked; compare it in constant time and never log the secrets.
  if not secrets.compare_digest(creds.password, api_key):
    raise HTTPException(
      status_code=status.HTTP_401_UNAUTHORIZED,
      detail="Incorrect Password",
      headers={"WWW-Authenticate": "Basic"}
    )

  text = data.text
  # Tokenize the prompt and run the classifier without tracking gradients.
  input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
  with torch.no_grad():
    logits = model(input_ids)[0]
  label_index = torch.argmax(logits, dim=-1).item()
  label = i2l[label_index]
  # Softmax converts the logits to class probabilities; report the winning class's probability.
  score = round(F.softmax(logits, dim=-1).squeeze()[label_index].item(), 3)
  return {"text": text, "label": label, "score": score}
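
# A minimal sketch of exercising the endpoint with FastAPI's TestClient. The
# import path `Sentiment` matches the import at the top of this file; the
# username is arbitrary because only the password is checked against API_KEY.
#
#   from fastapi.testclient import TestClient
#   from Sentiment import app
#
#   client = TestClient(app)
#   resp = client.post("/predict",
#                      auth=("anyuser", os.environ["API_KEY"]),
#                      json={"text": "I love this product"})
#   print(resp.status_code, resp.json())  # 200 {"text": ..., "label": ..., "score": ...}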