File size: 1,203 Bytes
729b0f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
import base64
import io
import os
import pickle

import requests
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Choose the compute device once at import time: "cuda" when torch reports
# an available CUDA device, otherwise "cpu".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"USING DEVICE: {DEVICE}")

# Identifier handed to both from_pretrained() calls so the tokenizer and the
# model weights always come from the same checkpoint.
_MODEL_ID = "xTRam1/safe-guard-classifier"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)

# Text-classification pipeline built from the objects loaded above, with
# truncation enabled and max_length set to 512, placed on DEVICE.
classifier = pipeline(
    "text-classification",
    tokenizer=tokenizer,
    model=model,
    device=torch.device(DEVICE),
    max_length=512,
    truncation=True,
)

# ASGI application object; routes are registered on it via decorators.
app = FastAPI()


@app.post("/generate")
async def generate(request: dict):
    """Classify the ``"text"`` field of the JSON request body.

    Args:
        request: Parsed JSON body. It must contain a ``"text"`` key, whose
            value is passed unchanged to the module-level ``classifier``.

    Returns:
        A ``JSONResponse`` that echoes the input under ``"text"`` and carries
        the classifier output under ``"result"``. When ``"text"`` is absent,
        a 422 ``JSONResponse`` with a ``"detail"`` message is returned
        instead.
    """
    # A missing key previously raised KeyError and surfaced as an opaque
    # HTTP 500; report it to the caller as a client error instead.
    if "text" not in request:
        return JSONResponse(
            status_code=422,
            content={"detail": "Request body must contain a 'text' field."},
        )

    # Named `text` rather than `input` so the builtin is not shadowed.
    text = request["text"]
    print("INPUT:", text)
    # NOTE(review): classifier() is a synchronous call made directly inside
    # an async handler — confirm this is acceptable for the expected load.
    result = classifier(text)
    print("RESULT:", result)
    return JSONResponse(content={"text": text, "result": result})


if __name__ == "__main__":
    # Script entry point: read the bind address from the command line and
    # serve `app` with uvicorn.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="TCP port to listen on (default: 8000).",
    )
    # The default keeps the previously hard-coded loopback address, so
    # existing invocations behave exactly as before.
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Interface address to bind to (default: 127.0.0.1).",
    )
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)