FaceMask / src /backend /safety_checker.py
tejani's picture
Upload 115 files
79eeb88 verified
raw
history blame contribute delete
751 Bytes
import logging
from typing import Any

from transformers import pipeline

from constants import SAFETY_CHECKER_MODEL
class SafetyChecker:
    """Decide whether an image is safe using an image-classification model.

    Wraps a ``transformers`` ``"image-classification"`` pipeline and compares
    the scores it returns for the ``"nsfw"`` and ``"normal"`` labels. The
    configured model is expected to emit those two labels.
    """

    # Per-call score diagnostics go through logging instead of ``print`` so
    # callers control whether (and where) they appear.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        mode_id: str = SAFETY_CHECKER_MODEL,
    ):
        """Build the classification pipeline.

        Args:
            mode_id: Model identifier forwarded to ``pipeline(model=...)``.
                The parameter name is a historical typo for "model_id"; it
                is kept as-is so existing keyword callers keep working.
        """
        self.classifier = pipeline(
            "image-classification",
            model=mode_id,
        )

    def is_safe(
        self,
        image: Any,
    ) -> bool:
        """Return True if the classifier scores ``image`` as more normal than NSFW.

        Args:
            image: Passed unchanged to the classification pipeline.
                NOTE(review): the lookup below assumes the pipeline returns
                a flat list of ``{"label": ..., "score": ...}`` dicts
                (single-image input). Confirm before passing batches.

        Returns:
            True only when the "normal" score is strictly greater than the
            "nsfw" score. A label missing from the prediction counts as 0,
            so ties, including an empty prediction, are treated as unsafe.
        """
        predictions = self.classifier(image)
        scores = {pred["label"]: pred["score"] for pred in predictions}
        nsfw_score = scores.get("nsfw", 0)
        normal_score = scores.get("normal", 0)
        self._logger.debug(
            "NSFW score: %s, Normal score: %s", nsfw_score, normal_score
        )
        return normal_score > nsfw_score