File size: 4,621 Bytes
9e31b92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import cv2
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torch
import gradio as gr
from models.common import DetectMultiBackend, NSFWModel
from utils.torch_utils import select_device
from utils.general import (check_img_size, non_max_suppression, scale_boxes)
from utils.plots import Annotator, colors
# Whole-image classifier used by the "Simple Check" mode of detect_nsfw.
nsfw_model = NSFWModel()

# YOLO detector used by the "Detailed Analysis" mode.
# NOTE(review): an empty device string presumably lets select_device choose the
# default device (e.g. first CUDA GPU, else CPU) — confirm in utils.torch_utils.
device = select_device('')
yolo_model = DetectMultiBackend(
    './weights/nsfw_detector_e_rok.pt',  # relative to the current working directory
    device=device,
    dnn=False,
    data=None,
    fp16=False,
)

# Model metadata read by the detection code below.
stride = yolo_model.stride
names = yolo_model.names  # class id -> class name, used for box labels
pt = yolo_model.pt

# Model input size. NOTE(review): check_img_size presumably adjusts (640, 640)
# to a multiple of the model stride — confirm in utils.general.
imgsz = check_img_size((640, 640), s=stride)
def resize_and_pad(image, target_size):
    """Letterbox *image* into *target_size*, preserving its aspect ratio.

    The image is scaled by the largest factor that fits inside the target, then
    centred and padded with black (zero) pixels to exactly ``target_size``.

    Args:
        image: Array of shape (H, W) or (H, W, C) as accepted by ``cv2.resize``.
        target_size: ``(height, width)`` of the output.

    Returns:
        A new array whose first two dimensions are ``(target_h, target_w)``.

    Raises:
        ValueError: If *image* has a zero height or width.
    """
    ih, iw = image.shape[:2]
    if ih == 0 or iw == 0:
        raise ValueError(f"cannot resize an empty image of shape {image.shape}")
    target_h, target_w = target_size

    # Largest uniform scale that keeps the whole image inside the target.
    scale = min(target_h / ih, target_w / iw)
    # round() rather than int(): float error in e.g. (640 / ih) * ih can land a
    # hair below the target and truncate a whole pixel. max(1, ...) keeps very
    # thin images from collapsing to a zero-sized side, which cv2.resize rejects.
    new_h = min(target_h, max(1, round(ih * scale)))
    new_w = min(target_w, max(1, round(iw * scale)))
    resized = cv2.resize(image, (new_w, new_h))  # cv2 takes (width, height)

    # Split the padding evenly; any odd pixel goes to the bottom/right side.
    pad_top = (target_h - new_h) // 2
    pad_left = (target_w - new_w) // 2
    return cv2.copyMakeBorder(
        resized,
        pad_top, target_h - new_h - pad_top,
        pad_left, target_w - new_w - pad_left,
        cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )
def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
    """Run the YOLO detector on *image* and return an annotated copy.

    Args:
        image: HxWx3 uint8 numpy array. The caller in this file passes an RGB
            image already letterboxed to ``imgsz`` by ``resize_and_pad``.
        conf_threshold: Confidence threshold passed to ``non_max_suppression``.
        iou_threshold: IoU threshold passed to ``non_max_suppression``.
        label_mode: "Draw box", "Draw Label", "Draw Confidence" or
            "Censor Predictions"; any other value draws nothing.

    Returns:
        The image produced by ``Annotator.result()``. When any detection has a
        class other than 6, and classes 4 and 10 are not both present, the
        whole image is Gaussian-blurred before boxes are drawn.

    Side effects:
        Shows a Gradio warning/info message describing the verdict.
    """
    # HWC uint8 -> 1xCxHxW float in [0, 1] on the model's device.
    im = torch.from_numpy(image).to(device).permute(2, 0, 1)
    im = im.half() if yolo_model.fp16 else im.float()
    im /= 255
    if im.ndim == 3:
        im = im[None]  # add a batch dimension
    # Resize to the model input size.
    im = torch.nn.functional.interpolate(im, size=imgsz, mode='bilinear', align_corners=False)

    # Inference only: no autograd graph is needed, so skip building one.
    with torch.no_grad():
        pred = yolo_model(im, augment=False, visualize=False)
        if isinstance(pred, list):
            pred = pred[0]
        pred = non_max_suppression(pred, conf_threshold, iou_threshold, None, False, max_det=1000)

    img = image.copy()
    harmful_label_list = []
    annotations = []
    for det in pred:
        if not len(det):
            continue
        # Map boxes from model-input coordinates back onto img.
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)
            # NOTE(review): class 6 is never counted as harmful — presumably a
            # neutral/safe class; confirm against `names`.
            if c != 6:
                harmful_label_list.append(c)
            annotations.append({
                'xyxy': xyxy,
                'conf': conf,
                'cls': c,
                'label': f"{names[c]} {conf:.2f}" if label_mode == "Draw Confidence" else f"{names[c]}",
            })

    # NOTE(review): classes 4 and 10 together are reported as underwear
    # (per the message below) and are not blurred — confirm against `names`.
    if 4 in harmful_label_list and 10 in harmful_label_list:
        gr.Warning("Warning: This image is featuring underwear.")
    elif harmful_label_list:
        # gr.Error is an exception class: constructing it without `raise` shows
        # nothing, and raising it would discard the blurred result. gr.Warning
        # displays the message without aborting.
        gr.Warning("Warning: This image may contain harmful content.")
        img = cv2.GaussianBlur(img, (125, 125), 0)
    else:
        gr.Info('This image appears to be safe.')

    if label_mode == "Censor Predictions":
        # Draw the filled boxes before building the Annotator: it may draw on a
        # PIL copy of the array (e.g. for non-ASCII label text), in which case
        # later edits to `img` would not appear in its result.
        for ann in annotations:
            x1, y1, x2, y2 = (int(v) for v in ann['xyxy'])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 0), -1)

    annotator = Annotator(img, line_width=3, example=str(names))
    if label_mode in ("Draw box", "Draw Label", "Draw Confidence"):
        for ann in annotations:
            label = None if label_mode == "Draw box" else ann['label']
            annotator.box_label(ann['xyxy'], label, color=colors(ann['cls'], True))
    return annotator.result()
def detect_nsfw(input_image, detection_mode, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
    """Classify or analyse an image for NSFW content.

    Args:
        input_image: An image URL (``str``), or an image array accepted by
            ``PIL.Image.fromarray`` (file upload).
        detection_mode: "Simple Check" runs ``nsfw_model.predict`` on the whole
            image; any other value runs the YOLO-based detailed analysis.
        conf_threshold: Detection confidence threshold (detailed mode only).
        iou_threshold: NMS IoU threshold (detailed mode only).
        label_mode: How detections are drawn; forwarded to process_image_yolo.

    Returns:
        ``(result, image)``. In simple mode this is the classifier's result and
        ``None``. In detailed mode it is a status message and the annotated image.

    Raises:
        gr.Error: If no image/URL was supplied or the URL could not be fetched.
    """
    if input_image is None or (isinstance(input_image, str) and not input_image.strip()):
        raise gr.Error("Please provide an image or an image URL.")

    if isinstance(input_image, str):  # URL input
        try:
            # Without a timeout, requests can block forever on an unresponsive
            # host. raise_for_status keeps HTTP error pages away from Image.open.
            response = requests.get(input_image.strip(), timeout=10)
            response.raise_for_status()
        except requests.RequestException as err:
            raise gr.Error(f"Could not download the image: {err}") from err
        image = Image.open(BytesIO(response.content))
    else:  # File upload
        image = Image.fromarray(input_image)

    if detection_mode == "Simple Check":
        result = nsfw_model.predict(image)
        return result, None

    # Detailed Analysis.
    # convert("RGB") normalises every PIL mode (L, LA, P, RGBA, CMYK, 1, ...) to
    # 3-channel RGB. Handling only 2-D and 4-channel arrays let modes such as
    # LA (2 channels) or P (palette indices) reach the detector malformed.
    image_np = np.array(image.convert("RGB"))
    # Letterbox to the model input size (imgsz, derived from (640, 640)).
    image_np = resize_and_pad(image_np, imgsz)
    processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
    return "Detailed analysis completed. See the image for results.", processed_image
|