Spaces:
Sleeping
Sleeping
added Topic Banner
Browse files
main.py
CHANGED
@@ -10,7 +10,6 @@ class Guardrail:
|
|
10 |
def __init__(self):
|
11 |
tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
|
12 |
model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
|
13 |
-
|
14 |
self.classifier = pipeline(
|
15 |
"text-classification",
|
16 |
model=model,
|
@@ -45,10 +44,6 @@ class ClassificationResult(BaseModel):
|
|
45 |
level: int
|
46 |
severity_label: str
|
47 |
|
48 |
-
app = FastAPI()
|
49 |
-
guardrail = Guardrail()
|
50 |
-
toxicity_classifier = Detoxify('original')
|
51 |
-
|
52 |
class ToxicityResult(BaseModel):
|
53 |
toxicity: float
|
54 |
severe_toxicity: float
|
@@ -57,6 +52,35 @@ class ToxicityResult(BaseModel):
|
|
57 |
insult: float
|
58 |
identity_attack: float
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
|
61 |
async def classify_toxicity(text_prompt: TextPrompt):
|
62 |
try:
|
@@ -83,6 +107,18 @@ async def classify_text(text_prompt: TextPrompt):
|
|
83 |
except Exception as e:
|
84 |
raise HTTPException(status_code=500, detail=str(e))
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
if __name__ == "__main__":
|
87 |
import uvicorn
|
88 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
10 |
def __init__(self):
|
11 |
tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
|
12 |
model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
|
|
|
13 |
self.classifier = pipeline(
|
14 |
"text-classification",
|
15 |
model=model,
|
|
|
44 |
level: int
|
45 |
severity_label: str
|
46 |
|
|
|
|
|
|
|
|
|
47 |
class ToxicityResult(BaseModel):
|
48 |
toxicity: float
|
49 |
severe_toxicity: float
|
|
|
52 |
insult: float
|
53 |
identity_attack: float
|
54 |
|
55 |
+
class TopicBannerClassifier:
    """Zero-shot topic classifier used to flag prompts touching banned topics.

    Wraps a Hugging Face ``zero-shot-classification`` pipeline
    (MoritzLaurer/deberta-v3-large-zeroshot-v2.0). The pipeline call is
    blocking, so :meth:`classify` dispatches it to a worker thread via
    ``run_in_threadpool`` to keep the event loop responsive.
    """

    # Default banned-topic label set; instances copy it so it is never
    # mutated through a shared reference.
    DEFAULT_LABELS = ("politics", "economy", "entertainment", "environment")
    DEFAULT_HYPOTHESIS_TEMPLATE = "This text is about {}"

    def __init__(self, classes_verbalized=None, hypothesis_template=None):
        """Build the classifier.

        Args:
            classes_verbalized: optional iterable of candidate topic labels;
                defaults to :data:`DEFAULT_LABELS`. Added as a
                backward-compatible generalization — no-arg construction
                behaves exactly as before.
            hypothesis_template: optional NLI hypothesis template containing
                a ``{}`` placeholder for the label; defaults to
                :data:`DEFAULT_HYPOTHESIS_TEMPLATE`.
        """
        # Model load is expensive; do it once at construction (startup).
        self.classifier = pipeline(
            "zero-shot-classification",
            model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
            # Prefer GPU when available; pipeline accepts a torch.device.
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.hypothesis_template = (
            hypothesis_template
            if hypothesis_template is not None
            else self.DEFAULT_HYPOTHESIS_TEMPLATE
        )
        self.classes_verbalized = (
            list(classes_verbalized)
            if classes_verbalized is not None
            else list(self.DEFAULT_LABELS)
        )

    async def classify(self, text):
        """Classify *text* against the configured labels.

        Returns the raw pipeline output dict with ``sequence``, ``labels``
        and ``scores`` keys (labels sorted by descending score,
        per the transformers zero-shot pipeline contract).
        """
        # The transformers pipeline is synchronous/CPU- or GPU-bound;
        # run it off the event loop.
        return await run_in_threadpool(
            self.classifier,
            text,
            self.classes_verbalized,
            hypothesis_template=self.hypothesis_template,
            multi_label=False
        )
|
73 |
+
|
74 |
+
class TopicBannerResult(BaseModel):
    """Response schema for the topic-banner endpoint.

    Mirrors the output dict of the transformers zero-shot-classification
    pipeline, which the endpoint forwards unchanged.
    """
    # Input text echoed back by the pipeline.
    sequence: str
    # Candidate topic labels — presumably strings sorted by descending
    # score (zero-shot pipeline convention); annotation left as bare
    # `list` to avoid changing pydantic validation.
    labels: list
    # Per-label confidence scores, aligned index-for-index with `labels`.
    scores: list
|
78 |
+
|
79 |
+
# Module-level singletons, created once at import/startup: the ASGI app and
# the three model wrappers it serves. Each constructor loads model weights,
# so process start-up cost is paid here rather than per request.
app = FastAPI()
guardrail = Guardrail()
toxicity_classifier = Detoxify('original')
topic_banner_classifier = TopicBannerClassifier()
|
83 |
+
|
84 |
@app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
|
85 |
async def classify_toxicity(text_prompt: TextPrompt):
|
86 |
try:
|
|
|
107 |
except Exception as e:
|
108 |
raise HTTPException(status_code=500, detail=str(e))
|
109 |
|
110 |
+
# NOTE(review): the "TopicBanner" path segment is CamelCase while sibling
# routes use lowercase ("toxicity"); left as-is since changing the URL
# would break existing callers — confirm before normalizing.
@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
async def classify_topic_banner(text_prompt: TextPrompt):
    """Classify the prompt against the banned-topic label set.

    Forwards the zero-shot pipeline output (sequence, labels, scores);
    any failure surfaces as HTTP 500 carrying the error message, matching
    the other endpoints in this module.
    """
    try:
        outcome = await topic_banner_classifier.classify(text_prompt.prompt)
        return {field: outcome[field] for field in ("sequence", "labels", "scores")}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
121 |
+
|
122 |
if __name__ == "__main__":
    # Local import: uvicorn is only needed when running this module
    # directly, not when an external ASGI server imports `app`.
    import uvicorn
    # Bind on all interfaces so the API is reachable from outside a
    # container; port 8000 is the conventional dev default.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|