File size: 1,203 Bytes
729b0f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import argparse
import base64
import io
import os
import pickle
import requests
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"USING DEVICE: {DEVICE}")
tokenizer = AutoTokenizer.from_pretrained("xTRam1/safe-guard-classifier")
model = AutoModelForSequenceClassification.from_pretrained(
"xTRam1/safe-guard-classifier"
)
classifier = pipeline(
"text-classification",
model=model,
tokenizer=tokenizer,
truncation=True,
max_length=512,
device=torch.device(DEVICE),
)
app = FastAPI()
@app.post("/generate")
async def generate(request: dict):
input = request["text"]
print("INPUT:", input)
result = classifier(input)
print("RESULT:", result)
return JSONResponse(content={"text": input, "result": result})
if __name__ == "__main__":
# print("here")
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
args = parser.parse_args()
port = args.port
uvicorn.run(app, host="127.0.0.1", port=port)
|