Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,371 Bytes
6527198 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from transformers import pipeline
from PIL import Image
import logging
import os
from reactor_utils import download
from scripts.reactor_logger import logger
def ensure_nsfw_model(nsfwdet_model_path):
    """Download any NSFW detection model files missing from a directory.

    The directory is created when absent, and each expected file is fetched
    only if no file with that name is already there. Checking per file,
    rather than only checking that the directory exists, lets a directory
    left empty or incomplete by an interrupted run be completed on the next
    call instead of being skipped forever.

    A file that already exists is assumed to be complete; no size or
    integrity check is performed.

    Args:
        nsfwdet_model_path: Directory that should hold the model files.
    """
    nd_urls = [
        "https://huggingface.co/AdamCodd/vit-base-nsfw-detector/resolve/main/config.json",
        "https://huggingface.co/AdamCodd/vit-base-nsfw-detector/resolve/main/model.safetensors",
        "https://huggingface.co/AdamCodd/vit-base-nsfw-detector/resolve/main/preprocessor_config.json",
    ]
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(nsfwdet_model_path, exist_ok=True)
    for model_url in nd_urls:
        model_name = os.path.basename(model_url)
        model_path = os.path.join(nsfwdet_model_path, model_name)
        if os.path.exists(model_path):
            # Already present: keep the existing file and skip the download.
            continue
        download(model_url, model_path, model_name)
# Exclusive lower bound on the classifier score: nsfw_image() reports a
# detection only when the top label is "nsfw" and its score is above this.
SCORE = 0.96
# Raise the "transformers" logger level at import time so that only ERROR and
# more severe records from that logger are handled.
logging.getLogger("transformers").setLevel(logging.ERROR)
def nsfw_image(img_path: str, model_path: str):
    """Report whether the image at ``img_path`` is classified as NSFW.

    The model files are ensured to be present in ``model_path`` first, then
    an image-classification pipeline is built from that directory (on every
    call) and run on the opened image. Only the first entry of the
    classifier output is inspected.

    Args:
        img_path: Path of the image file to check.
        model_path: Directory holding the NSFW detection model files.

    Returns:
        True when the first prediction is labelled "nsfw" with a score
        strictly greater than ``SCORE``; False otherwise.
    """
    ensure_nsfw_model(model_path)
    with Image.open(img_path) as image:
        classifier = pipeline("image-classification", model=model_path)
        top_prediction = classifier(image)[0]
    # The prediction is a plain result object, so the image can be closed
    # before the verdict is evaluated.
    flagged = top_prediction["label"] == "nsfw" and top_prediction["score"] > SCORE
    if not flagged:
        return False
    logger.status(f"NSFW content detected, skipping...")
    return True
|