from typing import Any

from transformers import pipeline

from constants import SAFETY_CHECKER_MODEL


class SafetyChecker:
    """Decide whether an image is safe using an image-classification pipeline.

    The checker wraps a ``transformers`` ``pipeline`` created for the
    ``"image-classification"`` task and compares the scores reported under
    the ``"normal"`` and ``"nsfw"`` labels.
    """

    def __init__(
        self,
        mode_id: str = SAFETY_CHECKER_MODEL,
    ):
        """Build the underlying classification pipeline.

        Args:
            mode_id: Model identifier passed to ``pipeline`` as ``model``.
                Defaults to ``SAFETY_CHECKER_MODEL``.
        """
        task = "image-classification"
        self.classifier = pipeline(task, model=mode_id)

    def is_safe(
        self,
        image: Any,
    ) -> bool:
        """Return ``True`` when the "normal" score beats the "nsfw" score.

        Args:
            image: Input handed straight to the classifier.
                NOTE(review): accepted input types depend on the pipeline;
                a single image is assumed here -- confirm against callers.

        Returns:
            ``True`` only if the score for ``"normal"`` is strictly greater
            than the score for ``"nsfw"``. A label missing from the
            prediction counts as a score of 0, so a tie (including the case
            where neither label is present) yields ``False``.
        """
        predictions = self.classifier(image)

        # Index the scores by label; a repeated label keeps its last score.
        score_by_label = {}
        for entry in predictions:
            label, score = entry["label"], entry["score"]
            score_by_label[label] = score

        nsfw = score_by_label.get("nsfw", 0)
        normal = score_by_label.get("normal", 0)
        print(f"NSFW score: {nsfw}, Normal score: {normal}")

        # Strict comparison: equal scores are treated as not safe.
        return normal > nsfw