llmgaurdrails / app.py
Sasidhar's picture
Update app.py
86f878d verified
raw
history blame
2.72 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import re
# FastAPI application instance; the route decorators further down in this
# file register their handlers on it.
app = FastAPI()
# Configuration for enabled guardrail checks as part of the request payload.
class GuardrailsConfig(BaseModel):
    # Per-request switches selecting which guardrail checks are executed.
    # Each switch defaults to True, so a payload only needs to mention the
    # checks it wants to turn off.
    factual_consistency: bool = True  # enable the factual-consistency rule
    toxicity: bool = True  # enable the toxicity rule
    # Add further boolean flags here as new guardrails are introduced.
# Request model now includes both the response and the configuration.
class CheckRequest(BaseModel):
    # Request body for a guardrail check: the text to evaluate plus the
    # switches that decide which rules run.
    response: str  # text that the enabled rules are evaluated against
    # When the payload omits "config", a default GuardrailsConfig is used.
    config: GuardrailsConfig = GuardrailsConfig()
class CheckResponse(BaseModel):
    # Response body for a guardrail check.
    grounded: bool  # overall verdict across the rules that ran
    details: dict  # individual outcomes, keyed by rule name
# A simple result class to hold individual check outcomes.
class Result:
    """Accumulate the pass/fail outcome of each guardrail rule that ran."""

    def __init__(self):
        # Maps rule name -> whether that rule passed.
        self.details = dict()

    def add(self, rule_name: str, passed: bool):
        """Store ``passed`` as the outcome for ``rule_name``.

        Adding the same rule name again overwrites the earlier outcome.
        """
        self.details.update({rule_name: passed})

    def grounded(self) -> bool:
        """Return True unless at least one recorded outcome is falsy.

        With no recorded outcomes the result is True as well.
        """
        for outcome in self.details.values():
            if not outcome:
                return False
        return True
# Define guardrail rule classes.
class FactualConsistencyRule:
    """Placeholder factual-consistency guardrail used for demonstration."""

    name = "FactualConsistency"

    def check(self, response_text: str) -> bool:
        """Return True when "fact" occurs anywhere in the text, ignoring case.

        This is a plain substring test, so words that merely contain
        "fact" (for example "factory") pass too.
        """
        lowered = response_text.lower()
        return lowered.find("fact") != -1
class ToxicityRule:
    """Placeholder toxicity guardrail used for demonstration."""

    name = "Toxicity"

    def check(self, response_text: str) -> bool:
        """Return True when the text contains no flagged word.

        A word is flagged when it starts with "hate" or "kill" (case
        insensitive), so "hateful" and "killing" are caught. The pattern is
        anchored at a word boundary so that unrelated words which merely
        contain those letters, such as "whatever" or "skill", no longer
        fail the check.
        """
        flagged = re.search(r"\b(?:hate|kill)", response_text, re.IGNORECASE)
        return flagged is None
# Manager class to load and execute the enabled guardrail rules.
class GuardrailsManager:
    """Build the enabled guardrail rules from a config and run them."""

    def __init__(self, config: GuardrailsConfig):
        self.config = config
        self.rules = self.load_rules()

    def load_rules(self):
        """Return one fresh rule instance per flag enabled in the config.

        The order is fixed: factual consistency first, then toxicity.
        """
        toggles = (
            (self.config.factual_consistency, FactualConsistencyRule),
            (self.config.toxicity, ToxicityRule),
            # Register additional (flag, rule class) pairs here.
        )
        return [rule_cls() for enabled, rule_cls in toggles if enabled]

    def check(self, response_text: str) -> Result:
        """Run every loaded rule on the text and collect outcomes by rule name."""
        outcome = Result()
        for active_rule in self.rules:
            passed = active_rule.check(response_text)
            outcome.add(active_rule.name, passed)
        return outcome
# Define the POST endpoint for guardrail checking.
@app.post("/guardrails/check", response_model=CheckResponse)
async def check_guardrails(request: CheckRequest):
    # Build a manager from the request's config, run the enabled rules on
    # the submitted text, and report the overall verdict together with the
    # per-rule outcomes.
    outcome = GuardrailsManager(request.config).check(request.response)
    is_grounded = outcome.grounded()
    return CheckResponse(grounded=is_grounded, details=outcome.details)
@app.get("/")
def home():
    # Root endpoint: answers with a fixed greeting payload.
    greeting = {"Msg": "This is a LLM Gaurdrails!"}
    return greeting
# Test endpoint returning a fixed payload.
# "/guardrails/test" matches the spelling used by the "/guardrails/check"
# route; the original misspelled "/gaurdrails/test" path stays registered
# (flagged as deprecated) so existing callers keep working.
@app.get("/guardrails/test")
@app.get("/gaurdrails/test", deprecated=True)
def test():
    return {"Msg": "This is a Test!"}