|
import gradio as gr |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import torch |
|
import sys |
|
from pathlib import Path |
|
import os |
|
# Put the directory containing this file on sys.path so the project-local
# imports that follow (``models``, ``utils`` - presumably located beside this
# file) resolve regardless of the directory the script is launched from.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
try:
    # Prefer a path relative to the current working directory.
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
except ValueError:
    # os.path.relpath raises ValueError on Windows when the two paths are on
    # different drives; the absolute ROOT is kept in that case.
    pass
|
|
|
from models.common import DetectMultiBackend |
|
from utils.general import (check_img_size, non_max_suppression, scale_boxes) |
|
from utils.plots import Annotator, colors |
|
from utils.torch_utils import select_device |
|
|
|
|
|
# Load the detector once at import time; every request reuses these globals.
device = select_device('')  # empty string: let select_device choose the device

# Resolve the weights relative to this file's directory (ROOT) rather than the
# current working directory, so the app also starts when launched elsewhere.
weights_path = str(ROOT / 'weights' / 'nsfw_detector_e_rok.pt')
model = DetectMultiBackend(weights_path, device=device, dnn=False, data=None, fp16=False)
stride, names, pt = model.stride, model.names, model.pt

# Network input size, passed through check_img_size together with the model stride.
imgsz = check_img_size((640, 640), s=stride)
|
|
|
def process_image(image, conf_threshold, iou_threshold, label_mode):
    """Run the detector on an image and return an annotated copy.

    Args:
        image: ``H x W x C`` NumPy array (it is permuted to channel-first and
            divided by 255, so 0-255 pixel values are expected).
        conf_threshold: Confidence threshold forwarded to
            ``non_max_suppression``.
        iou_threshold: IoU threshold forwarded to ``non_max_suppression``.
        label_mode: ``"Draw box"``, ``"Draw Label"``, ``"Draw Confidence"`` or
            ``"Censor Predictions"``. Any other value draws nothing.

    Returns:
        The image returned by ``Annotator.result()``. When at least one
        detection has a class id other than 6, the whole image is
        Gaussian-blurred before anything is drawn on it.
    """
    # Detections of this class do not flag the image as harmful.
    # NOTE(review): presumably the model's neutral/"safe" class - confirm
    # against model.names.
    benign_class_id = 6

    # HWC array -> normalised NCHW tensor on the model's device.
    im = torch.from_numpy(image).to(device).permute(2, 0, 1)
    im = im.half() if model.fp16 else im.float()
    im /= 255
    if im.ndim == 3:
        im = im[None]  # add the batch dimension
    im = torch.nn.functional.interpolate(im, size=imgsz, mode='bilinear', align_corners=False)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pred = model(im, augment=False, visualize=False)
        if isinstance(pred, list):
            pred = pred[0]
        pred = non_max_suppression(pred, conf_threshold, iou_threshold, None, False, max_det=1000)

    img = image.copy()
    harmful = False
    annotations = []
    for det in pred:
        if not len(det):
            continue
        # Map boxes from the network input size back to the image size.
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)
            if c != benign_class_id:
                harmful = True
            label = f"{names[c]} {conf:.2f}" if label_mode == "Draw Confidence" else f"{names[c]}"
            annotations.append({'xyxy': xyxy, 'conf': conf, 'cls': c, 'label': label})

    if harmful:
        # gr.Warning shows an alert and lets processing continue. A bare
        # gr.Error(...) that is never raised has no effect, and raising it
        # would abort before the blurred image could be returned.
        gr.Warning("Warning, this is a harmful image.")
        img = cv2.GaussianBlur(img, (125, 125), 0)
    else:
        gr.Info('This is a safe image.')

    if label_mode == "Censor Predictions":
        # Paint the filled rectangles before the Annotator is created.
        # NOTE(review): the Annotator may render on its own copy of the image
        # (e.g. a PIL-backed mode); drawing first keeps the rectangles in the
        # result either way - confirm against utils.plots.Annotator.
        for ann in annotations:
            x1, y1, x2, y2 = (int(v) for v in ann['xyxy'])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 0), -1)

    annotator = Annotator(img, line_width=3, example=str(names))

    with_text = label_mode in ("Draw Label", "Draw Confidence")
    if with_text or label_mode == "Draw box":
        for ann in annotations:
            text = ann['label'] if with_text else None
            annotator.box_label(ann['xyxy'], text, color=colors(ann['cls'], True))

    return annotator.result()
|
|
|
def detect_nsfw(input_image, conf_threshold, iou_threshold, label_mode):
    """Gradio handler: load an image, normalise it and run ``process_image``.

    Args:
        input_image: Either a NumPy image array or a URL string pointing to an
            image to download.
        conf_threshold: Confidence threshold passed on to ``process_image``.
        iou_threshold: IoU threshold passed on to ``process_image``.
        label_mode: Label display mode passed on to ``process_image``.

    Returns:
        The annotated image returned by ``process_image``.

    Raises:
        gr.Error: If no image was provided.
        requests.RequestException: If the URL cannot be fetched or the server
            answers with an error status.
    """
    if input_image is None:
        # Submitting the form without an image would otherwise crash inside
        # Image.fromarray with an opaque AttributeError.
        raise gr.Error("Please provide an image.")

    if isinstance(input_image, str):
        # NOTE(review): this fetches a caller-supplied URL from the server;
        # restrict allowed hosts if this path is exposed to untrusted users.
        response = requests.get(input_image, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    else:
        image = Image.fromarray(input_image)

    # Normalise every mode (grayscale, RGBA, palette, LA, ...) to 3-channel RGB.
    image = np.array(image.convert("RGB"))

    # imgsz is used as (height, width) by process_image's interpolate call,
    # while cv2.resize expects dsize as a (width, height) tuple.
    height, width = (int(v) for v in imgsz)
    image = cv2.resize(image, (width, height))

    return process_image(image, conf_threshold, iou_threshold, label_mode)
|
|
|
|
|
# Gradio front end: one image input plus the two thresholds and the drawing
# mode, all forwarded to detect_nsfw; the processed image is shown as output.
demo = gr.Interface(
    fn=detect_nsfw,
    inputs=[
        gr.Image(type="numpy", label="Upload an image or enter a URL"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="Confidence Threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="Overlap Threshold"),
        gr.Dropdown(
            choices=["Draw box", "Draw Label", "Draw Confidence", "Censor Predictions"],
            value="Draw box",
            label="Label Display Mode",
        ),
    ],
    outputs=gr.Image(type="numpy", label="Processed Image"),
    title="YOLOv9 NSFW Content Detection",
    description="Upload an image or enter a URL to detect NSFW content using YOLOv9.",
)


if __name__ == "__main__":
    # Bind to 0.0.0.0 so the app is reachable on all network interfaces.
    demo.launch(server_name="0.0.0.0")