AISimplyExplained commited on
Commit
5c19b8d
·
verified ·
1 Parent(s): 338f4c1

added Topic Banner

Browse files
Files changed (1) hide show
  1. main.py +42 -6
main.py CHANGED
@@ -10,7 +10,6 @@ class Guardrail:
10
  def __init__(self):
11
  tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
12
  model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
13
-
14
  self.classifier = pipeline(
15
  "text-classification",
16
  model=model,
@@ -45,10 +44,6 @@ class ClassificationResult(BaseModel):
45
  level: int
46
  severity_label: str
47
 
48
- app = FastAPI()
49
- guardrail = Guardrail()
50
- toxicity_classifier = Detoxify('original')
51
-
52
  class ToxicityResult(BaseModel):
53
  toxicity: float
54
  severe_toxicity: float
@@ -57,6 +52,35 @@ class ToxicityResult(BaseModel):
57
  insult: float
58
  identity_attack: float
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  @app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
61
  async def classify_toxicity(text_prompt: TextPrompt):
62
  try:
@@ -83,6 +107,18 @@ async def classify_text(text_prompt: TextPrompt):
83
  except Exception as e:
84
  raise HTTPException(status_code=500, detail=str(e))
85
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  if __name__ == "__main__":
87
  import uvicorn
88
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
10
  def __init__(self):
11
  tokenizer = AutoTokenizer.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
12
  model = AutoModelForSequenceClassification.from_pretrained("ProtectAI/deberta-v3-base-prompt-injection")
 
13
  self.classifier = pipeline(
14
  "text-classification",
15
  model=model,
 
44
  level: int
45
  severity_label: str
46
 
 
 
 
 
47
  class ToxicityResult(BaseModel):
48
  toxicity: float
49
  severe_toxicity: float
 
52
  insult: float
53
  identity_attack: float
54
 
55
+ class TopicBannerClassifier:
56
+ def __init__(self):
57
+ self.classifier = pipeline(
58
+ "zero-shot-classification",
59
+ model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
60
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+ )
62
+ self.hypothesis_template = "This text is about {}"
63
+ self.classes_verbalized = ["politics", "economy", "entertainment", "environment"]
64
+
65
+ async def classify(self, text):
66
+ return await run_in_threadpool(
67
+ self.classifier,
68
+ text,
69
+ self.classes_verbalized,
70
+ hypothesis_template=self.hypothesis_template,
71
+ multi_label=False
72
+ )
73
+
74
+ class TopicBannerResult(BaseModel):
75
+ sequence: str
76
+ labels: list
77
+ scores: list
78
+
79
+ app = FastAPI()
80
+ guardrail = Guardrail()
81
+ toxicity_classifier = Detoxify('original')
82
+ topic_banner_classifier = TopicBannerClassifier()
83
+
84
  @app.post("/api/models/toxicity/classify", response_model=ToxicityResult)
85
  async def classify_toxicity(text_prompt: TextPrompt):
86
  try:
 
107
  except Exception as e:
108
  raise HTTPException(status_code=500, detail=str(e))
109
 
110
+ @app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
111
+ async def classify_topic_banner(text_prompt: TextPrompt):
112
+ try:
113
+ result = await topic_banner_classifier.classify(text_prompt.prompt)
114
+ return {
115
+ "sequence": result["sequence"],
116
+ "labels": result["labels"],
117
+ "scores": result["scores"]
118
+ }
119
+ except Exception as e:
120
+ raise HTTPException(status_code=500, detail=str(e))
121
+
122
  if __name__ == "__main__":
123
  import uvicorn
124
+ uvicorn.run(app, host="0.0.0.0", port=8000)