# Spaces: Sleeping
import asyncio
import logging
from typing import List

import torch
from detoxify import Detoxify
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
class Guardrail:
    """Prompt-injection detector wrapping a Hugging Face text-classification pipeline.

    The ``ProtectAI/deberta-v3-base-prompt-injection`` tokenizer and model are
    loaded when the object is constructed. Inputs are truncated to
    ``max_length=512``. The pipeline runs on CUDA when
    ``torch.cuda.is_available()`` is true, otherwise on CPU.
    """

    def __init__(self):
        checkpoint = "ProtectAI/deberta-v3-base-prompt-injection"
        tok = AutoTokenizer.from_pretrained(checkpoint)
        net = AutoModelForSequenceClassification.from_pretrained(checkpoint)
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classifier = pipeline(
            "text-classification",
            model=net,
            tokenizer=tok,
            truncation=True,
            max_length=512,
            device=target,
        )

    async def guard(self, prompt):
        """Classify *prompt* without blocking the event loop.

        The synchronous pipeline call runs in a worker thread via
        ``run_in_threadpool``. Its raw output is returned unchanged; callers in
        this file read ``result[0]['label']`` and ``result[0]['score']`` from it.
        """
        predictions = await run_in_threadpool(self.classifier, prompt)
        return predictions

    def determine_level(self, label, score):
        """Map a classifier ``label``/``score`` pair to ``(level, severity_label)``.

        ``"SAFE"`` always yields ``(0, "safe")``. Any other label is graded by
        ``score`` with strict ``>`` comparisons:

        - above 0.9 gives ``(4, "high")``
        - above 0.75 gives ``(3, "medium")``
        - above 0.5 gives ``(2, "low")``
        - anything else gives ``(1, "very low")``
        """
        if label == "SAFE":
            return 0, "safe"
        # Checked from the highest threshold down; the first band exceeded wins.
        bands = ((0.9, 4, "high"), (0.75, 3, "medium"), (0.5, 2, "low"))
        for threshold, level, severity in bands:
            if score > threshold:
                return level, severity
        return 1, "very low"
# Request body shared by every classification handler: the raw text to score.
class TextPrompt(BaseModel):
    prompt: str
# Response shape matching the dict built by classify_text: the classifier's
# top label/score plus the (level, severity_label) pair from
# Guardrail.determine_level (level is 0-4).
class ClassificationResult(BaseModel):
    label: str
    score: float
    level: int
    severity_label: str
# Response shape matching the dict built by classify_toxicity: one score per
# Detoxify category, in the same key order.
class ToxicityResult(BaseModel):
    toxicity: float
    severe_toxicity: float
    obscene: float
    threat: float
    insult: float
    identity_attack: float
class TopicBannerClassifier:
    """Zero-shot topic classifier scoring text against a fixed list of topics.

    Uses ``MoritzLaurer/deberta-v3-large-zeroshot-v2.0`` through a Hugging Face
    ``zero-shot-classification`` pipeline. It runs on CUDA when available,
    otherwise on CPU.
    """

    def __init__(self):
        compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classifier = pipeline(
            "zero-shot-classification",
            model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
            device=compute_device,
        )
        # Each candidate topic is substituted for "{}" to build the hypothesis
        # sentence handed to the pipeline.
        self.hypothesis_template = "This text is about {}"
        self.classes_verbalized = ["politics", "economy", "entertainment", "environment"]

    async def classify(self, text):
        """Score *text* against ``classes_verbalized`` in a worker thread.

        Single-label mode is used (``multi_label=False``). The pipeline's raw
        output is returned unchanged.
        """
        call_kwargs = {
            "hypothesis_template": self.hypothesis_template,
            "multi_label": False,
        }
        return await run_in_threadpool(
            self.classifier, text, self.classes_verbalized, **call_kwargs
        )
# Response shape matching the dict built by classify_topic_banner: the input
# text plus the candidate topics and their scores. Item types are declared so
# the payload is validated and the published schema describes the array
# contents; presumably labels[i] pairs with scores[i] (zero-shot pipeline
# output) -- confirm against the installed transformers version.
class TopicBannerResult(BaseModel):
    sequence: str
    labels: List[str]
    scores: List[float]
# Application object and model singletons. All three models are loaded once,
# at import time, and shared by every request handler below.
# NOTE(review): none of the handlers in this file carry an ``@app.post(...)``
# (or similar) decorator, so no routes are visible on ``app`` -- confirm the
# registration was not lost.
app = FastAPI()
guardrail = Guardrail()
toxicity_classifier = Detoxify("original")
topic_banner_classifier = TopicBannerClassifier()
async def classify_toxicity(text_prompt: TextPrompt):
    """Score ``text_prompt.prompt`` with the Detoxify ``original`` model.

    Prediction runs in a worker thread so the event loop is not blocked.

    Returns:
        A dict of plain ``float`` scores keyed ``toxicity``, ``severe_toxicity``,
        ``obscene``, ``threat``, ``insult`` and ``identity_attack`` (the fields
        of ``ToxicityResult``).

    Raises:
        HTTPException: status 500 carrying ``str(e)`` if prediction or result
            extraction fails. The original exception is logged with its
            traceback and chained as the cause.
    """
    categories = (
        "toxicity",
        "severe_toxicity",
        "obscene",
        "threat",
        "insult",
        "identity_attack",
    )
    try:
        result = await run_in_threadpool(toxicity_classifier.predict, text_prompt.prompt)
        # Detoxify may return NumPy scalars (not plain floats) for a single
        # string, which the JSON encoder rejects -- TODO confirm for the
        # installed version. float() is a no-op for values that are already floats.
        return {name: float(result[name]) for name in categories}
    except Exception as e:
        # HTTPException responses are not logged by FastAPI, so record the
        # traceback here before converting the error.
        logging.getLogger(__name__).exception("Toxicity classification failed")
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def classify_text(text_prompt: TextPrompt):
    """Run the prompt-injection guardrail on ``text_prompt.prompt``.

    Returns:
        A dict with the top ``label`` and ``score`` from ``guardrail.guard``,
        plus the ``level`` and ``severity_label`` from
        ``guardrail.determine_level`` (the fields of ``ClassificationResult``).

    Raises:
        HTTPException: status 500 carrying ``str(e)`` if classification fails.
            The original exception is logged with its traceback and chained as
            the cause.
    """
    try:
        result = await guardrail.guard(text_prompt.prompt)
        # A single input string yields a list whose first entry is the top prediction.
        top = result[0]
        label = top['label']
        score = top['score']
        level, severity_label = guardrail.determine_level(label, score)
        return {"label": label, "score": score, "level": level, "severity_label": severity_label}
    except Exception as e:
        # HTTPException responses are not logged by FastAPI, so record the
        # traceback here before converting the error.
        logging.getLogger(__name__).exception("Prompt-injection classification failed")
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def classify_topic_banner(text_prompt: TextPrompt):
    """Zero-shot topic classification of ``text_prompt.prompt``.

    Returns:
        A dict with the pipeline's ``sequence``, ``labels`` and ``scores``
        (the fields of ``TopicBannerResult``).

    Raises:
        HTTPException: status 500 carrying ``str(e)`` if classification fails.
            The original exception is logged with its traceback and chained as
            the cause.
    """
    try:
        result = await topic_banner_classifier.classify(text_prompt.prompt)
        return {
            "sequence": result["sequence"],
            "labels": result["labels"],
            "scores": result["scores"],
        }
    except Exception as e:
        # HTTPException responses are not logged by FastAPI, so record the
        # traceback here before converting the error.
        logging.getLogger(__name__).exception("Topic classification failed")
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the file is run as a script.
    import uvicorn

    # Listen on every interface, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)