# safeguard/aihack/launch_model.py
# Uploaded by sijju via huggingface_hub ("Upload folder using huggingface_hub",
# verified commit 729b0f4, 1.2 kB).
import argparse
import base64
import io
import os
import pickle
import requests
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"USING DEVICE: {DEVICE}")

# Hugging Face Hub id of the sequence-classification checkpoint. Both the
# tokenizer and the model weights are loaded from it at import time.
_MODEL_ID = "xTRam1/safe-guard-classifier"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)

# Shared text-classification pipeline used by the HTTP endpoint. Because of
# truncation=True / max_length=512, over-long inputs are cut to 512 tokens
# instead of raising.
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=torch.device(DEVICE),
    truncation=True,
    max_length=512,
)

# ASGI application; started with uvicorn in the __main__ section.
app = FastAPI()
@app.post("/generate")
async def generate(request: dict):
    """Classify ``request["text"]`` and return the pipeline output as JSON.

    Args:
        request: Parsed JSON request body. Must contain a ``"text"`` entry,
            whose value is passed unchanged to the text-classification
            pipeline (so any input form the pipeline accepts still works).

    Returns:
        A ``JSONResponse`` of ``{"text": <input>, "result": <pipeline output>}``.
        If ``"text"`` is missing, it returns a 422 ``JSONResponse`` with a
        ``"detail"`` message instead of failing with an unhandled ``KeyError``.
    """
    if "text" not in request:
        return JSONResponse(
            status_code=422,
            content={"detail": "Request body must include a 'text' field."},
        )
    # Named ``text`` rather than ``input`` to avoid shadowing the builtin.
    text = request["text"]
    print("INPUT:", text)
    # NOTE(review): this synchronous model call runs directly on the event
    # loop (inside ``async def``), so no other request is handled until it
    # returns. Before offloading it to a worker thread, confirm the shared
    # pipeline/tokenizer is safe to call concurrently (or guard it with a lock).
    result = classifier(text)
    print("RESULT:", result)
    return JSONResponse(content={"text": text, "result": result})
if __name__ == "__main__":
    # Command-line entry point: serve the FastAPI app with uvicorn.
    parser = argparse.ArgumentParser(
        description="Serve the safe-guard text classifier over HTTP."
    )
    parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="Interface to bind (default: 127.0.0.1, reachable only locally).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to listen on (default: 8000).",
    )
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)