"""Gradio handlers for NSFW image checking.

Two modes are offered: a whole-image classifier (``NSFWModel``) and a YOLO
detector (``DetectMultiBackend``) that draws, labels or censors detections.
"""
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image

from models.common import DetectMultiBackend, NSFWModel
from utils.torch_utils import select_device
from utils.general import (check_img_size, non_max_suppression, scale_boxes)
from utils.plots import Annotator, colors

# Load classification model (used by the "Simple Check" mode).
nsfw_model = NSFWModel()

# Load YOLO model (used by the detailed-analysis mode).
device = select_device('')
yolo_model = DetectMultiBackend('./weights/nsfw_detector_e_rok.pt', device=device, dnn=False, data=None, fp16=False)
stride, names, pt = yolo_model.stride, yolo_model.names, yolo_model.pt
imgsz = check_img_size((640, 640), s=stride)

# Seconds to wait for a remote image (URL input) before giving up.
URL_TIMEOUT_S = 10

# Detector class id that is never added to the harmful-class set.
# NOTE(review): presumably the detector's "safe" class - confirm against `names`.
SAFE_CLASS_ID = 6
# When all of these class ids are detected, only the "underwear" warning is shown
# and the image is NOT blurred. NOTE(review): meaning inferred from the warning text.
UNDERWEAR_CLASS_IDS = (4, 10)
# Kernel of the full-image Gaussian blur applied to harmful images (must be odd).
BLUR_KERNEL = (125, 125)


def resize_and_pad(image, target_size):
    """Letterbox ``image`` into ``target_size`` while keeping its aspect ratio.

    The image is scaled to fit entirely inside the target. The leftover area is
    filled with black and split evenly between opposite sides; any odd pixel
    goes to the bottom/right.

    Args:
        image: Array accepted by ``cv2.resize``, indexed as (height, width, ...).
        target_size: ``(height, width)`` of the output.

    Returns:
        The padded array, exactly ``target_size`` in height and width.
    """
    ih, iw = image.shape[:2]
    target_h, target_w = target_size

    # Largest scale at which the whole image still fits inside the target.
    scale = min(target_h / ih, target_w / iw)

    # Clamp each side to at least 1 px. Very elongated inputs would otherwise
    # truncate a side to 0, which makes cv2.resize raise.
    new_h = max(1, int(ih * scale))
    new_w = max(1, int(iw * scale))

    resized = cv2.resize(image, (new_w, new_h))  # cv2 expects (width, height)

    pad_top = (target_h - new_h) // 2
    pad_left = (target_w - new_w) // 2
    pad_bottom = target_h - new_h - pad_top
    pad_right = target_w - new_w - pad_left
    return cv2.copyMakeBorder(resized, pad_top, pad_bottom, pad_left, pad_right,
                              cv2.BORDER_CONSTANT, value=[0, 0, 0])


def _to_model_input(image):
    """Convert an (H, W, C) image array into a 1xCxHxW float tensor resized to ``imgsz``."""
    im = torch.from_numpy(image).to(device).permute(2, 0, 1)
    im = im.half() if yolo_model.fp16 else im.float()
    im /= 255  # assumes 0-255 pixel values -> scale to 0.0-1.0
    im = im[None]  # add the batch dimension
    # A no-op when the caller already letterboxed the image to imgsz.
    return torch.nn.functional.interpolate(im, size=imgsz, mode='bilinear', align_corners=False)


def _draw_annotations(img, annotations, label_mode):
    """Render ``annotations`` on ``img`` according to ``label_mode``; unknown modes draw nothing."""
    if label_mode == "Censor Predictions":
        # Paint the filled black boxes BEFORE wrapping the image in Annotator.
        # NOTE(review): upstream YOLOv5's Annotator may hold its own PIL copy of
        # the image (e.g. for non-ASCII labels). Drawing on `img` afterwards
        # would then never reach annotator.result().
        for ann in annotations:
            x1, y1, x2, y2 = (int(v) for v in ann['xyxy'])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 0), -1)

    annotator = Annotator(img, line_width=3, example=str(names))
    for ann in annotations:
        if label_mode == "Draw box":
            annotator.box_label(ann['xyxy'], None, color=colors(ann['cls'], True))
        elif label_mode in ("Draw Label", "Draw Confidence"):
            annotator.box_label(ann['xyxy'], ann['label'], color=colors(ann['cls'], True))
    return annotator.result()


def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
    """Run the YOLO detector on ``image`` and return an annotated copy.

    Shows a Gradio toast summarizing the outcome. When any class other than
    ``SAFE_CLASS_ID`` is detected, the whole image is blurred before drawing.
    The exception is when every id in ``UNDERWEAR_CLASS_IDS`` is present: then
    only a warning is shown.

    Args:
        image: (H, W, 3) RGB array; presumably already letterboxed to ``imgsz``
            by ``resize_and_pad`` (the only caller in this file does so).
        conf_threshold: Minimum detection confidence passed to NMS.
        iou_threshold: IoU threshold passed to NMS.
        label_mode: "Draw box", "Draw Label", "Draw Confidence" or
            "Censor Predictions"; any other value draws nothing.

    Returns:
        The annotated image as returned by ``Annotator.result()``.
    """
    im = _to_model_input(image)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        pred = yolo_model(im, augment=False, visualize=False)
    if isinstance(pred, list):
        pred = pred[0]

    pred = non_max_suppression(pred, conf_threshold, iou_threshold, None, False, max_det=1000)

    img = image.copy()
    harmful_classes = set()
    annotations = []
    for det in pred:  # one entry per image in the batch
        if not len(det):
            continue
        # Map boxes from model-input coordinates back onto `img`.
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)
            if c != SAFE_CLASS_ID:
                harmful_classes.add(c)
            # NOTE(review): every detection (including SAFE_CLASS_ID) is annotated.
            # Confirm whether safe-class boxes should be drawn/censored.
            annotations.append({
                'xyxy': xyxy,
                'conf': conf,
                'cls': c,
                'label': f"{names[c]} {conf:.2f}" if label_mode == "Draw Confidence" else f"{names[c]}",
            })

    if all(c in harmful_classes for c in UNDERWEAR_CLASS_IDS):
        gr.Warning("Warning: This image is featuring underwear.")
    elif harmful_classes:
        # gr.Error is an exception class: constructing it without `raise` shows
        # nothing, and raising it would abort before the blurred image is
        # returned. gr.Warning displays the message and lets processing continue.
        gr.Warning("Warning: This image may contain harmful content.")
        img = cv2.GaussianBlur(img, BLUR_KERNEL, 0)
    else:
        gr.Info('This image appears to be safe.')

    return _draw_annotations(img, annotations, label_mode)


def detect_nsfw(input_image, detection_mode, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
    """Gradio entry point: classify or analyse an image for NSFW content.

    Args:
        input_image: An image URL (``str``) or an image array from a Gradio upload.
        detection_mode: "Simple Check" runs the whole-image classifier; any other
            value runs the YOLO detailed analysis.
        conf_threshold: Detection confidence threshold (detailed mode only).
        iou_threshold: NMS IoU threshold (detailed mode only).
        label_mode: How detections are drawn (see ``process_image_yolo``).

    Returns:
        ``(result, None)`` with the classifier output in simple mode, or a status
        message and the annotated image in detailed mode.

    Raises:
        gradio.Error: If no image / URL was provided.
        requests.RequestException: If downloading the URL fails, times out or
            returns an HTTP error status.
    """
    if input_image is None or (isinstance(input_image, str) and not input_image.strip()):
        raise gr.Error("Please provide an image or an image URL.")

    if isinstance(input_image, str):  # URL input
        # NOTE(review): fetching arbitrary user-supplied URLs server-side is an
        # SSRF risk if this app is exposed publicly.
        response = requests.get(input_image.strip(), timeout=URL_TIMEOUT_S)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    else:  # File upload
        image = Image.fromarray(input_image)

    # Normalize every PIL mode (L, LA, P, RGBA, CMYK, ...) to 3-channel RGB.
    # Converting via PIL rather than by array shape correctly handles palette
    # images (whose raw arrays hold palette indices), CMYK and 2-channel 'LA'.
    if image.mode != "RGB":
        image = image.convert("RGB")

    if detection_mode == "Simple Check":
        result = nsfw_model.predict(image)
        return result, None

    # Detailed Analysis: letterbox to the detector input size (imgsz, 640x640
    # checked against the model stride at load time) before running YOLO.
    image_np = resize_and_pad(np.array(image), imgsz)
    processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
    return "Detailed analysis completed. See the image for results.", processed_image