# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
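"""Aegis content safety guardrail.

Wraps the meta-llama/LlamaGuard-7b base model with the
nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0 PEFT adapter and
classifies an input prompt as safe or unsafe against UNSAFE_CATEGORIES.

Illustrative CLI usage (the module path is a placeholder; adjust it to where
this file lives in your package):

    python -m <package>.aegis --prompt "describe a sunset over the ocean"

Pass --checkpoint_dir to use a checkpoint folder other than the default
checkpoints/Cosmos-1.0-Guardrail/aegis.
"""
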
import argparse

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from .categories import UNSAFE_CATEGORIES
from .guardrail_core import ContentSafetyGuardrail, GuardrailRunner
from .log import log
from .misc import misc, Color, timer

SAFE = Color.green("SAFE")
UNSAFE = Color.red("UNSAFE")

DEFAULT_CHECKPOINT_DIR = "checkpoints/Cosmos-1.0-Guardrail/aegis"


class Aegis(ContentSafetyGuardrail):
    def __init__(
        self,
        checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        self.checkpoint_dir = checkpoint_dir
        self.device = device
        self.dtype = torch.bfloat16
        base_model_id = "meta-llama/LlamaGuard-7b"
        aegis_adapter = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"
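        # Load the LlamaGuard base model and tokenizer, then attach the Aegis
        # PEFT adapter on top of the base weights.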
        base_model = AutoModelForCausalLM.from_pretrained(base_model_id, cache_dir=self.checkpoint_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, cache_dir=self.checkpoint_dir)
        self.model = PeftModel.from_pretrained(base_model, aegis_adapter, cache_dir=self.checkpoint_dir)
        self.model.to(self.device, dtype=self.dtype).eval()

    def get_moderation_prompt(self, user_prompt: str) -> str:
        """Create the moderation prompt for the Aegis model."""
full_prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
{"".join([f"{category['instruction']}" for category in UNSAFE_CATEGORIES])}
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
"""
full_prompt += f"User: {user_prompt}\n\n"
full_prompt += """<END CONVERSATION>
Provide your safety assessment for User in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
return full_prompt
    def get_aegis_block_message(self, moderation_output: str) -> str:
        """Extract the blocked category and reason from the Aegis model output."""
        block_msg = "Prompt blocked by Aegis."
        try:
            # Extract blocked category and reason
            if len(moderation_output.split("\n")) >= 2:
                moderation_output_list = moderation_output.split("\n")
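                # The second line lists violated categories as codes such as "O1";
                # strip the leading "O" and subtract 1 to index into UNSAFE_CATEGORIES.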
                blocked_category = int(moderation_output_list[1].lstrip("O")) - 1
                if 0 <= blocked_category < len(UNSAFE_CATEGORIES):
                    blocked_reason = UNSAFE_CATEGORIES[blocked_category]["blocked_reason"]
                    blocked_category_name = UNSAFE_CATEGORIES[blocked_category]["category"]
                    block_msg = f"{blocked_category_name}: {blocked_reason}"
        except Exception as e:
            log.warning(f"Unable to extract blocked category and reason from Aegis output: {e}")
        return block_msg

    def filter_aegis_output(self, prompt: str) -> tuple[bool, str]:
        """Filter the Aegis model output and return the safety status and message."""
        full_prompt = self.get_moderation_prompt(prompt)
        inputs = self.tokenizer([full_prompt], add_special_tokens=False, return_tensors="pt").to(self.device)
        output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=self.tokenizer.eos_token_id)
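        # Decode only the newly generated tokens (everything after the prompt)
        # to obtain the moderation verdict.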
        prompt_len = inputs["input_ids"].shape[-1]
        moderation_output = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        if "unsafe" in moderation_output.lower():
            block_msg = self.get_aegis_block_message(moderation_output)
            return False, block_msg
        else:
            return True, ""

    def is_safe(self, prompt: str) -> tuple[bool, str]:
        """Check if the input prompt is safe according to the Aegis model."""
        try:
            return self.filter_aegis_output(prompt)
        except Exception as e:
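            # Fail open: if the guardrail itself errors out, log the error but
            # do not block the prompt.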
            log.error(f"Unexpected error occurred when running Aegis guardrail: {e}")
            return True, "Unexpected error occurred when running Aegis guardrail."


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the Aegis checkpoint folder",
        default=DEFAULT_CHECKPOINT_DIR,
    )
    return parser.parse_args()


def main(args):
    aegis = Aegis(checkpoint_dir=args.checkpoint_dir)
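    # GuardrailRunner can chain multiple safety models; only Aegis is used here.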
    runner = GuardrailRunner(safety_models=[aegis])
    with timer("aegis safety check"):
        safety, message = runner.run_safety_check(args.prompt)
log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
log.info(f"Message: {message}") if not safety else None
if __name__ == "__main__":
    args = parse_args()
    main(args)