deberta_api / main.py
AISimplyExplained's picture
added Topic Banner
5c19b8d verified
raw
history blame
4.2 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
from detoxify import Detoxify
import asyncio
from fastapi.concurrency import run_in_threadpool
class Guardrail:
def __init__(self):
tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
self.classifier = pipeline(
"text-classification",
model=model,
tokenizer=tokenizer,
truncation=True,
max_length=512,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
async def guard(self, prompt):
return await run_in_threadpool(self.classifier, prompt)
def determine_level(self, label, score):
if label == "SAFE":
return 0, "safe"
else:
if score > 0.9:
return 4, "high"
elif score > 0.75:
return 3, "medium"
elif score > 0.5:
return 2, "low"
else:
return 1, "very low"
class TextPrompt(BaseModel):
prompt: str
class ClassificationResult(BaseModel):
label: str
score: float
level: int
severity_label: str
class ToxicityResult(BaseModel):
toxicity: float
severe_toxicity: float
obscene: float
threat: float
insult: float
identity_attack: float
class TopicBannerClassifier:
def __init__(self):
self.classifier = pipeline(
"zero-shot-classification",
model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
self.hypothesis_template = "This text is about {}"
self.classes_verbalized = ["politics", "economy", "entertainment", "environment"]
async def classify(self, text):
return await run_in_threadpool(
self.classifier,
text,
self.classes_verbalized,
hypothesis_template=self.hypothesis_template,
multi_label=False
)
class TopicBannerResult(BaseModel):
sequence: str
labels: list
scores: list
app = FastAPI()
guardrail = Guardrail()
toxicity_classifier = Detoxify('original')
topic_banner_classifier = TopicBannerClassifier()
@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
async def classify_toxicity(text_prompt: TextPrompt):
try:
result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
return {
"toxicity": result['toxicity'],
"severe_toxicity": result['severe_toxicity'],
"obscene": result['obscene'],
"threat": result['threat'],
"insult": result['insult'],
"identity_attack": result['identity_attack']
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/models/PromptInjection/classify", response_model=ClassificationResult)
async def classify_text(text_prompt: TextPrompt):
try:
result = await guardrail.guard(text_prompt.prompt)
label = result[0]['label']
score = result[0]['score']
level, severity_label = guardrail.determine_level(label, score)
return {"label": label, "score": score, "level": level, "severity_label": severity_label}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
async def classify_topic_banner(text_prompt: TextPrompt):
try:
result = await topic_banner_classifier.classify(text_prompt.prompt)
return {
"sequence": result["sequence"],
"labels": result["labels"],
"scores": result["scores"]
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)