|
import cv2 |
|
import numpy as np |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import torch |
|
import gradio as gr |
|
|
|
from models.common import DetectMultiBackend |
|
from utils.torch_utils import select_device |
|
from utils.general import check_img_size, non_max_suppression, scale_boxes |
|
from utils.plots import Annotator, colors |
|
from config.settings import DETECT_MODEL_PATH |
|
|
|
|
|
|
|
|
|
|
|
# Module-level detector setup: built once at import time and shared by the
# functions below.
device = select_device("")
yolo_model = DetectMultiBackend(
    DETECT_MODEL_PATH,
    device=device,
    dnn=False,
    data=None,
    fp16=False,
)

# Model metadata used when preparing inputs and labelling detections.
stride = yolo_model.stride
names = yolo_model.names
pt = yolo_model.pt

# Requested 640x640 inference size, validated against the model stride.
imgsz = check_img_size((640, 640), s=stride)
|
|
|
|
|
def resize_and_pad(image, target_size):
    """Letterbox *image* into ``target_size`` while keeping its aspect ratio.

    The image is scaled by the largest factor that still lets it fit inside
    the target, then centred on a black canvas of exactly ``target_size``.

    Args:
        image: Array whose first two dimensions are (height, width).
        target_size: ``(height, width)`` of the returned canvas.

    Returns:
        The resized image surrounded by constant black borders so that its
        height and width equal ``target_size``.
    """
    ih, iw = image.shape[:2]
    target_h, target_w = target_size

    # One factor for both axes so the content is not distorted.
    scale = min(target_h / ih, target_w / iw)

    # Truncation can give 0 for very elongated inputs (e.g. 1x2000 -> 640x640)
    # and cv2.resize fails on an empty destination size, so keep >= 1 px.
    new_h = max(1, int(ih * scale))
    new_w = max(1, int(iw * scale))

    resized = cv2.resize(image, (new_w, new_h))

    # Split the leftover space evenly; any odd pixel goes to bottom/right.
    top = (target_h - new_h) // 2
    left = (target_w - new_w) // 2
    bottom = target_h - new_h - top
    right = target_w - new_w - left

    return cv2.copyMakeBorder(
        resized,
        top,
        bottom,
        left,
        right,
        cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )
|
|
|
|
|
def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
    """Run the detector on *image* and render the detections per *label_mode*.

    Args:
        image: HxWxC numpy array; it is permuted to CxHxW before inference.
        conf_threshold: Confidence threshold passed to non-max suppression.
        iou_threshold: IoU threshold passed to non-max suppression.
        label_mode: "Draw box", "Draw Label", "Draw Confidence" or
            "Censor Predictions". Any other value leaves the image unannotated.

    Returns:
        ``(annotated_image, label_counts)`` where ``label_counts`` maps each
        detected class name to the number of detections of that class.
    """
    # Inference only: without no_grad the forward pass would record an
    # autograd graph that nothing ever uses.
    with torch.no_grad():
        im = torch.from_numpy(image).to(device).permute(2, 0, 1)
        im = im.half() if yolo_model.fp16 else im.float()
        # Out-of-place division: torch.from_numpy shares memory with `image`,
        # so an in-place divide could modify the caller's array when it is
        # already float32.
        im = im / 255
        # permute() above only accepts a 3-D tensor, so the batch dimension
        # is always missing at this point.
        im = im[None]

        im = torch.nn.functional.interpolate(
            im, size=imgsz, mode="bilinear", align_corners=False
        )

        pred = yolo_model(im, augment=False, visualize=False)
        if isinstance(pred, list):
            pred = pred[0]

        pred = non_max_suppression(
            pred, conf_threshold, iou_threshold, None, False, max_det=1000
        )

    img = image.copy()
    show_conf = label_mode == "Draw Confidence"
    annotations = []
    label_counts = {}

    for det in pred:
        if not len(det):
            continue

        # Map boxes from the network input size back to the image size.
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()

        for *xyxy, conf, cls in reversed(det):
            c = int(cls)
            label_name = names[c]
            label_counts[label_name] = label_counts.get(label_name, 0) + 1
            annotations.append(
                {
                    "xyxy": xyxy,
                    "conf": conf,
                    "cls": c,
                    "label": (
                        f"{label_name} {conf:.2f}" if show_conf else f"{label_name}"
                    ),
                }
            )

    if label_mode == "Censor Predictions":
        # Painted before the Annotator is built so the filled boxes end up in
        # the result whether Annotator draws on `img` in place or on a copy.
        for ann in annotations:
            x1, y1, x2, y2 = (int(v) for v in ann["xyxy"])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 0), -1)

    annotator = Annotator(img, line_width=3, example=str(names))

    if label_mode in ("Draw box", "Draw Label", "Draw Confidence"):
        for ann in annotations:
            # "Draw box" outlines the detection without any caption.
            text = None if label_mode == "Draw box" else ann["label"]
            annotator.box_label(ann["xyxy"], text, color=colors(ann["cls"], True))

    return annotator.result(), label_counts
|
|
|
|
|
def detect_nsfw(input_image, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
    """Detect content in an image and summarise the findings as text.

    Args:
        input_image: Either a URL string to download the image from, or an
            array accepted by ``PIL.Image.fromarray``.
        conf_threshold: Confidence threshold forwarded to the detector.
        iou_threshold: IoU threshold forwarded to the detector.
        label_mode: Annotation style forwarded to ``process_image_yolo``.

    Returns:
        ``(result_text, processed_image)``: a human-readable summary of the
        detected labels and the annotated image.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    if isinstance(input_image, str):
        # NOTE(review): if this URL can come from untrusted users, fetching it
        # server-side is an SSRF vector — confirm whether hosts are restricted.
        # The timeout keeps a stalled server from blocking the request forever.
        response = requests.get(input_image, timeout=10)
        # Fail here on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    else:
        image = Image.fromarray(input_image)

    # Let PIL normalise every mode (L, LA, P, RGBA, CMYK, ...) to 3-channel
    # RGB; alpha is dropped and grayscale is replicated across channels.
    image_np = np.array(image.convert("RGB"))

    image_np = resize_and_pad(image_np, imgsz)
    processed_image, label_counts = process_image_yolo(
        image_np, conf_threshold, iou_threshold, label_mode
    )

    if "head" in label_counts and len(label_counts) == 1:
        # Only "head" detections: report them but still call the image safe.
        head_count = label_counts["head"]
        result_text = (
            f"Detected content:\n - head: {head_count} instance(s)\nThis image appears to be safe."
        )
    elif label_counts:
        result_text = "Detected content:\n" + "".join(
            f" - {label}: {count} instance(s)\n"
            for label, count in label_counts.items()
        )
    else:
        result_text = "This image appears to be safe."

    return result_text, processed_image
|
|