Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
from detoxify import Detoxify
|
6 |
import asyncio
|
7 |
from fastapi.concurrency import run_in_threadpool
|
8 |
-
from typing import List
|
9 |
|
10 |
class Guardrail:
|
11 |
def __init__(self):
|
@@ -80,6 +80,16 @@ class TopicBannerResult(BaseModel):
|
|
80 |
labels: list
|
81 |
scores: list
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
app = FastAPI()
|
84 |
guardrail = Guardrail()
|
85 |
toxicity_classifier = Detoxify('original')
|
@@ -123,6 +133,35 @@ async def classify_topic_banner(request: TopicBannerRequest):
|
|
123 |
except Exception as e:
|
124 |
raise HTTPException(status_code=500, detail=str(e))
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
if __name__ == "__main__":
|
127 |
import uvicorn
|
128 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
5 |
from detoxify import Detoxify
|
6 |
import asyncio
|
7 |
from fastapi.concurrency import run_in_threadpool
|
8 |
+
from typing import List, Optional
|
9 |
|
10 |
class Guardrail:
|
11 |
def __init__(self):
|
|
|
80 |
labels: list
|
81 |
scores: list
|
82 |
|
83 |
+
class GuardrailsRequest(BaseModel):
|
84 |
+
prompt: str
|
85 |
+
guardrails: List[str]
|
86 |
+
labels: Optional[List[str]] = None
|
87 |
+
|
88 |
+
class GuardrailsResponse(BaseModel):
|
89 |
+
prompt_injection: Optional[ClassificationResult] = None
|
90 |
+
toxicity: Optional[ToxicityResult] = None
|
91 |
+
topic_banner: Optional[TopicBannerResult] = None
|
92 |
+
|
93 |
app = FastAPI()
|
94 |
guardrail = Guardrail()
|
95 |
toxicity_classifier = Detoxify('original')
|
|
|
133 |
except Exception as e:
|
134 |
raise HTTPException(status_code=500, detail=str(e))
|
135 |
|
136 |
+
@app.post("/api/guardrails", response_model=GuardrailsResponse)
|
137 |
+
async def evaluate_guardrails(request: GuardrailsRequest):
|
138 |
+
tasks = []
|
139 |
+
response = GuardrailsResponse()
|
140 |
+
|
141 |
+
if "pi" in request.guardrails:
|
142 |
+
tasks.append(classify_text(TextPrompt(prompt=request.prompt)))
|
143 |
+
if "tox" in request.guardrails:
|
144 |
+
tasks.append(classify_toxicity(TextPrompt(prompt=request.prompt)))
|
145 |
+
if "top" in request.guardrails:
|
146 |
+
if not request.labels:
|
147 |
+
raise HTTPException(status_code=400, detail="Labels are required for topic banner classification")
|
148 |
+
tasks.append(classify_topic_banner(TopicBannerRequest(prompt=request.prompt, labels=request.labels)))
|
149 |
+
|
150 |
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
151 |
+
|
152 |
+
for result, guardrail in zip(results, request.guardrails):
|
153 |
+
if isinstance(result, Exception):
|
154 |
+
# Handle the exception as needed
|
155 |
+
continue
|
156 |
+
if guardrail == "pi":
|
157 |
+
response.prompt_injection = result
|
158 |
+
elif guardrail == "tox":
|
159 |
+
response.toxicity = result
|
160 |
+
elif guardrail == "top":
|
161 |
+
response.topic_banner = result
|
162 |
+
|
163 |
+
return response
|
164 |
+
|
165 |
if __name__ == "__main__":
|
166 |
import uvicorn
|
167 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|