# Sana-1.6B / nsfw_detector.py
import torch
from transformers import AutoProcessor, FocalNetForImageClassification
from PIL import Image, ImageDraw, ImageFont

class NSFWDetector:
    """Wraps TostAI's FocalNet NSFW classifier for single-image checks."""

    def __init__(self):
        self.model_path = "TostAI/nsfw-image-detection-large"
        # The AutoProcessor already handles resizing and normalization,
        # so no separate torchvision transform is needed.
        self.feature_extractor = AutoProcessor.from_pretrained(self.model_path)
        self.model = FocalNetForImageClassification.from_pretrained(self.model_path)
        self.model.eval()
        self.label_to_category = {
            "LABEL_0": "Safe",
            "LABEL_1": "Questionable",
            "LABEL_2": "Unsafe",
        }
    def check_image(self, image):
        """Classify a PIL image; returns (is_nsfw, category, confidence_pct)."""
        # Convert image to RGB if it isn't already
        image = image.convert("RGB")

        # Preprocess with the model's own processor
        inputs = self.feature_extractor(images=image, return_tensors="pt")

        # Run inference without tracking gradients
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            confidence, predicted = torch.max(probabilities, 1)

        # Map the raw label (e.g. "LABEL_0") to a readable category
        label = self.model.config.id2label[predicted.item()]
        category = self.label_to_category.get(label, label)

        return category != "Safe", category, confidence.item() * 100
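

# A hedged sketch, not part of the original module: the pipeline above runs on
# CPU. If a CUDA device is available, the same logic can be run on the GPU by
# moving the model and the processor's outputs to that device. The helper name
# check_image_on and its device argument are assumptions for illustration.
def check_image_on(detector, image, device=None):
    """Hypothetical helper: run the same classification on a chosen device."""
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = detector.model.to(device)
    inputs = detector.feature_extractor(
        images=image.convert("RGB"), return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    confidence, predicted = torch.max(probabilities, 1)
    label = model.config.id2label[predicted.item()]
    category = detector.label_to_category.get(label, label)
    return category != "Safe", category, confidence.item() * 100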

def create_error_image(message="NSFW Content Detected"):
    """Return a 512x512 black placeholder image with the message centered."""
    img = Image.new('RGB', (512, 512), color='black')
    draw = ImageDraw.Draw(img)

    try:
        # Pillow ships a built-in bitmap font, so no font file is required
        font = ImageFont.load_default()

        # Center the text using its bounding box
        text_bbox = draw.textbbox((0, 0), message, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        x = (512 - text_width) // 2
        y = (512 - text_height) // 2

        # Draw white text on the black background
        draw.text((x, y), message, fill='white', font=font)
    except Exception as e:
        print(f"Error adding text to image: {e}")

    return img
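

# A minimal usage sketch (not part of the original module): it assumes a local
# file "input.png" exists and shows how a caller is meant to combine the
# detector with the placeholder image. The file names are hypothetical.
if __name__ == "__main__":
    detector = NSFWDetector()
    source = Image.open("input.png")
    is_nsfw, category, confidence = detector.check_image(source)
    print(f"Category: {category} ({confidence:.1f}% confidence)")
    if is_nsfw:
        # Replace flagged output with the black placeholder image
        source = create_error_image()
    source.save("output.png")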