import cv2
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torch
import gradio as gr

from models.common import DetectMultiBackend, NSFWModel
from utils.torch_utils import select_device
from utils.general import check_img_size, non_max_suppression, scale_boxes
from utils.plots import Annotator, colors

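# Whole-image NSFW classifier used by the "Simple Check" mode.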
nsfw_model = NSFWModel()

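# YOLO detector and its metadata, used for the detailed box-level analysis.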
device = select_device('')
yolo_model = DetectMultiBackend('./weights/nsfw_detector_e_rok.pt', device=device, dnn=False, data=None, fp16=False)
stride, names, pt = yolo_model.stride, yolo_model.names, yolo_model.pt
imgsz = check_img_size((640, 640), s=stride)


def resize_and_pad(image, target_size):
    """Letterbox `image` to `target_size` (h, w): scale to fit, then pad with black."""
    ih, iw = image.shape[:2]
    target_h, target_w = target_size

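    # Scale to fit inside the target while preserving the aspect ratio, then resize.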
    scale = min(target_h / ih, target_w / iw)
    new_h, new_w = int(ih * scale), int(iw * scale)
    resized = cv2.resize(image, (new_w, new_h))

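    # Split the remaining space between both sides so the resized image is centered.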
    pad_h = (target_h - new_h) // 2
    pad_w = (target_w - new_w) // 2
    padded = cv2.copyMakeBorder(resized, pad_h, target_h - new_h - pad_h, pad_w, target_w - new_w - pad_w,
                                cv2.BORDER_CONSTANT, value=[0, 0, 0])

    return padded


def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
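    # Convert the HWC uint8 array to a normalized NCHW float tensor on the model's device.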
    im = torch.from_numpy(image).to(device).permute(2, 0, 1)
    im = im.half() if yolo_model.fp16 else im.float()
    im /= 255
    if len(im.shape) == 3:
        im = im[None]

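    # Resize the tensor to the model's expected input size before inference.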
    im = torch.nn.functional.interpolate(im, size=imgsz, mode='bilinear', align_corners=False)

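    # Forward pass, then filter overlapping detections with non-maximum suppression.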
    pred = yolo_model(im, augment=False, visualize=False)
    if isinstance(pred, list):
        pred = pred[0]

    pred = non_max_suppression(pred, conf_threshold, iou_threshold, None, False, max_det=1000)

    img = image.copy()
    harmful_label_list = []
    annotations = []

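    # Collect detections, rescaling boxes from the model input size back to the image size.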
    for i, det in enumerate(pred):
        if len(det):
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()

            for *xyxy, conf, cls in reversed(det):
                c = int(cls)
                if c != 6:  # class index 6 is excluded from the harmful-content check
                    harmful_label_list.append(c)

                annotation = {
                    'xyxy': xyxy,
                    'conf': conf,
                    'cls': c,
                    'label': f"{names[c]} {conf:.2f}" if label_mode == "Draw Confidence" else f"{names[c]}"
                }
                annotations.append(annotation)

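    # Surface an overall verdict via Gradio notifications, blurring the image for harmful detections.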
    if 4 in harmful_label_list and 10 in harmful_label_list:
        gr.Warning("Warning: This image appears to feature underwear.")
    elif harmful_label_list:
        gr.Error("Warning: This image may contain harmful content.")
        img = cv2.GaussianBlur(img, (125, 125), 0)
    else:
        gr.Info('This image appears to be safe.')

    annotator = Annotator(img, line_width=3, example=str(names))

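    # Draw boxes/labels, or censor detected regions, depending on the selected label mode.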
    for ann in annotations:
        if label_mode == "Draw box":
            annotator.box_label(ann['xyxy'], None, color=colors(ann['cls'], True))
        elif label_mode in ["Draw Label", "Draw Confidence"]:
            annotator.box_label(ann['xyxy'], ann['label'], color=colors(ann['cls'], True))
        elif label_mode == "Censor Predictions":
            cv2.rectangle(img, (int(ann['xyxy'][0]), int(ann['xyxy'][1])),
                          (int(ann['xyxy'][2]), int(ann['xyxy'][3])), (0, 0, 0), -1)

    return annotator.result()


def detect_nsfw(input_image, detection_mode, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
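    # Accept either an image URL (string) or an in-memory array from the Gradio input.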
    if isinstance(input_image, str):
        response = requests.get(input_image)
        image = Image.open(BytesIO(response.content))
    else:
        image = Image.fromarray(input_image)

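    # Normalize grayscale and RGBA inputs to 3-channel RGB.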
    image_np = np.array(image)
    if len(image_np.shape) == 2:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.shape[2] == 4:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

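    # "Simple Check" runs the whole-image classifier; otherwise run the YOLO detection pipeline.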
    if detection_mode == "Simple Check":
        result = nsfw_model.predict(image)
        return result, None
    else:
        image_np = resize_and_pad(image_np, imgsz)
        processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
        return "Detailed analysis completed. See the image for results.", processed_image