|
from typing import Any |
|
|
|
from transformers import pipeline |
|
|
|
from constants import SAFETY_CHECKER_MODEL |
|
|
|
|
|
class SafetyChecker:
    """Decide whether an image is safe by comparing classifier label scores.

    The instance holds an ``image-classification`` pipeline in
    ``self.classifier`` and uses the scores it reports for the ``"nsfw"`` and
    ``"normal"`` labels.
    """

    def __init__(self, mode_id: str = SAFETY_CHECKER_MODEL):
        """Create the underlying image-classification pipeline.

        Args:
            mode_id: Model identifier forwarded to ``pipeline`` as ``model``.
                Defaults to ``SAFETY_CHECKER_MODEL``.
        """
        self.classifier = pipeline("image-classification", model=mode_id)

    def is_safe(self, image: Any) -> bool:
        """Return ``True`` when the "normal" score beats the "nsfw" score.

        Args:
            image: Input handed unchanged to ``self.classifier``.

        Returns:
            ``True`` only if the score for ``"normal"`` is strictly greater
            than the score for ``"nsfw"``. A label absent from the
            predictions counts as 0, so a tie (including both labels
            missing) yields ``False``.
        """
        # NOTE(review): assumes the classifier yields an iterable of mappings
        # that each carry "label" and "score" keys -- confirm for batched input.
        predictions = self.classifier(image)

        score_by_label = {}
        for entry in predictions:
            label_name = entry["label"]
            score_by_label[label_name] = entry["score"]

        nsfw = score_by_label.get("nsfw", 0)
        normal = score_by_label.get("normal", 0)

        print(f"NSFW score: {nsfw}, Normal score: {normal}")

        return normal > nsfw
|
|