Update guardrail.py

guardrail.py  CHANGED  (+24 -34)
@@ -1,52 +1,43 @@
 from together import Together
-from dotenv import load_dotenv
-import os
+from helper import get_together_api_key
 
-load_dotenv()
-
-api_key=os.getenv("API_KEY")
 # Initialize Together client
-client = Together(api_key=api_key)
+client = Together(api_key=get_together_api_key())
 
-# Set the content policy
-content_policy = {
+# Set the safety content policy
+safe_content_policy = {
     "policy": """
-U1: Topic Restriction.
+O1: Topic Restriction.
 Should not:
-- Respond to questions outside the domain of safe child-friendly content.
-- Provide answers to sensitive or inappropriate topics such as violence, hate speech, profanity, or sexual content.
-Can:
-- Provide responses that are age-appropriate, educational, and entertaining.
-
-U2: Personal Information.
+- Respond to questions outside the domain of safe child-friendly content.
+- Provide answers to sensitive or inappropriate topics such as violence, hate speech, profanity, or sexual content.
+O2: Sensitive Information.
 Should not:
-- Request or respond with any personal or sensitive user data.
+- Request or respond with any personal or sensitive user data.
 Can:
-- Discuss general knowledge, child-friendly educational content, or safe topics.
-
-U3: Language.
+- Discuss general knowledge, child-friendly educational content, or safe topics.
+O3: Language.
 Should not:
 - Use profane, offensive, or inappropriate language.
-- Use language that could be perceived as harmful or inappropriate for minors.
 Can:
-- Use simple, clear, and friendly language suitable for children.
+- Use simple, clear, and friendly language suitable for children.
 """
 }
 
 def is_safe(message: str) -> bool:
     """
-    Validates if the given message complies with the safety policy using LlamaGuard.
-
+    Validates if the given message complies with the safety policy using LlamaGuard.
+
     Args:
         message (str): User message or chatbot response.
-
+
     Returns:
         bool: True if the message is safe, False otherwise.
     """
-    # Prepare the LlamaGuard prompt
-    prompt = f"""[INST] Task: Validate if the given 'message' adheres to the safety policy below:
+    # Prepare the LlamaGuard prompt
+    prompt = f"""[INST] Task: Validate if the given 'message' adheres to the safety policy below:
     <BEGIN SAFETY POLICY>
-    {content_policy['policy']}
+    {safe_content_policy['policy']}
     <END SAFETY POLICY>
     <BEGIN MESSAGE>
     {message}
@@ -56,16 +47,15 @@ def is_safe(message: str) -> bool:
     - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
 
     try:
-        # Send the prompt to the LlamaGuard model
+        # Send the prompt to the LlamaGuard model
         response = client.completions.create(
-            model="meta-llama/Meta-Llama-Guard-3-8B",
+            model="meta-llama/Meta-Llama-Guard-3-8B",
             prompt=prompt
         )
 
-        # Parse the result
+        # Parse the result
         result = response.choices[0].text.strip().lower()
-        return result
-
+        return result.startswith('safe')  # Ensure 'safe' is at the beginning
    except Exception as e:
-        print(f"Error in guardrail check: {e}")
-        return False
+        print(f"Error in guardrail check: {e}")
+        return False  # Default to unsafe if an error occurs