import cv2
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torch
import gradio as gr
from models.common import DetectMultiBackend
from utils.torch_utils import select_device
from utils.general import check_img_size, non_max_suppression, scale_boxes
from utils.plots import Annotator, colors
from config.settings import DETECT_MODEL_PATH

# NOTE: an image-level classifier (NSFWModel from models.common) used to back a
# "Simple Check" mode. That path is currently disabled; only YOLO detection runs.

# Seconds to wait for a remote image before giving up.
_URL_FETCH_TIMEOUT_S = 10

# Load YOLO model (done once, at import time).
device = select_device("")
yolo_model = DetectMultiBackend(DETECT_MODEL_PATH, device=device, dnn=False, data=None, fp16=False)
stride, names, pt = yolo_model.stride, yolo_model.names, yolo_model.pt
imgsz = check_img_size((640, 640), s=stride)


def resize_and_pad(image, target_size):
    """Fit ``image`` inside ``target_size`` keeping its aspect ratio, then pad with black.

    Args:
        image: Array whose first two dimensions are (height, width).
        target_size: ``(target_h, target_w)`` of the returned image.

    Returns:
        The resized image centred on a black canvas of ``target_size``. When the
        leftover space is odd, the extra pixel goes to the bottom/right border.
    """
    ih, iw = image.shape[:2]
    target_h, target_w = target_size

    # Largest uniform scale at which both dimensions still fit in the target.
    scale = min(target_h / ih, target_w / iw)

    # Clamp to 1 px: for extremely elongated inputs the short side would
    # otherwise truncate to 0, which cv2.resize rejects.
    new_h = max(1, int(ih * scale))
    new_w = max(1, int(iw * scale))

    resized = cv2.resize(image, (new_w, new_h))

    # Split the leftover space evenly between the two sides.
    pad_top = (target_h - new_h) // 2
    pad_left = (target_w - new_w) // 2

    padded = cv2.copyMakeBorder(
        resized,
        pad_top,
        target_h - new_h - pad_top,
        pad_left,
        target_w - new_w - pad_left,
        cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )
    return padded


def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
    """Run the YOLO detector on ``image`` and render the detections.

    Args:
        image: H x W x C numpy array (it is permuted to C x H x W for the model).
        conf_threshold: Confidence threshold forwarded to ``non_max_suppression``.
        iou_threshold: IoU threshold forwarded to ``non_max_suppression``.
        label_mode: One of ``"Draw box"``, ``"Draw Label"``, ``"Draw Confidence"``
            or ``"Censor Predictions"``. Any other value leaves the image unmarked.

    Returns:
        ``(annotated_image, label_counts)`` where ``annotated_image`` is
        ``Annotator.result()`` for a copy of ``image`` and ``label_counts`` maps
        each detected class name to its number of detections.
    """
    # Image preprocessing: HWC -> CHW, model precision, scale to [0, 1].
    im = torch.from_numpy(image).to(device).permute(2, 0, 1)
    im = im.half() if yolo_model.fp16 else im.float()
    # Out-of-place division: when `image` is already float32 on CPU, `.float()`
    # returns a tensor sharing memory with it, and `/=` would alter the caller's array.
    im = im / 255
    if len(im.shape) == 3:
        im = im[None]  # add batch dimension

    # Resize image to the model input size.
    im = torch.nn.functional.interpolate(im, size=imgsz, mode="bilinear", align_corners=False)

    # Inference + NMS. Gradients are never needed here, so skip autograd bookkeeping.
    with torch.no_grad():
        pred = yolo_model(im, augment=False, visualize=False)
        if isinstance(pred, list):
            pred = pred[0]
        pred = non_max_suppression(
            pred, conf_threshold, iou_threshold, None, False, max_det=1000
        )

    # Collect detections, with boxes mapped back to the coordinates of `image`.
    img = image.copy()
    annotations = []
    label_counts = {}
    for det in pred:
        if not len(det):
            continue
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)
            label_name = names[c]
            label_counts[label_name] = label_counts.get(label_name, 0) + 1
            annotations.append(
                {
                    "xyxy": xyxy,
                    "conf": conf,
                    "cls": c,
                    "label": (
                        f"{label_name} {conf:.2f}"
                        if label_mode == "Draw Confidence"
                        else f"{label_name}"
                    ),
                }
            )

    # Censor boxes are painted straight onto `img` BEFORE the Annotator is built.
    # NOTE(review): Annotator is defined elsewhere; if it ever works on its own
    # copy of the image, rectangles painted afterwards would be missing from
    # `annotator.result()`. Painting first is correct in both cases.
    if label_mode == "Censor Predictions":
        for ann in annotations:
            x1, y1, x2, y2 = (int(v) for v in ann["xyxy"])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 0), -1)

    # Annotate image
    annotator = Annotator(img, line_width=3, example=str(names))
    if label_mode == "Draw box":
        for ann in annotations:
            annotator.box_label(ann["xyxy"], None, color=colors(ann["cls"], True))
    elif label_mode in ("Draw Label", "Draw Confidence"):
        for ann in annotations:
            annotator.box_label(ann["xyxy"], ann["label"], color=colors(ann["cls"], True))

    return annotator.result(), label_counts


def detect_nsfw(input_image, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
    """Detect content in an image and summarise the findings as text.

    Args:
        input_image: Either a URL string (the image is downloaded) or an array
            accepted by ``PIL.Image.fromarray`` (file upload).
        conf_threshold: Confidence threshold passed to ``process_image_yolo``.
        iou_threshold: IoU threshold passed to ``process_image_yolo``.
        label_mode: Rendering mode passed to ``process_image_yolo``.

    Returns:
        ``(result_text, processed_image)``: a human-readable summary and the
        rendered image. The image is reported as safe when nothing is detected
        or when ``"head"`` is the only detected label.

    Raises:
        requests.RequestException: If a URL input cannot be fetched (network
            error, timeout, or non-success HTTP status).
    """
    if isinstance(input_image, str):  # URL input
        # NOTE(review): this fetches an arbitrary user-supplied URL. If the app is
        # exposed publicly, consider restricting schemes/hosts (SSRF).
        response = requests.get(input_image, timeout=_URL_FETCH_TIMEOUT_S)
        # Fail on HTTP errors instead of trying to decode an error page as an image.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    else:  # File upload
        image = Image.fromarray(input_image)

    # Normalise every colour mode (grayscale, RGBA, palette, LA, CMYK, ...) to
    # 3-channel RGB, which is what the detector pipeline expects.
    image_np = np.array(image.convert("RGB"))

    # Resize image and process with YOLO (imgsz was requested as (640, 640) above).
    image_np = resize_and_pad(image_np, imgsz)
    processed_image, label_counts = process_image_yolo(
        image_np, conf_threshold, iou_threshold, label_mode
    )

    # Construct detailed result text
    if "head" in label_counts and len(label_counts) == 1:
        head_count = label_counts["head"]
        result_text = (
            f"Detected content:\n - head: {head_count} instance(s)\nThis image appears to be safe."
        )
    elif label_counts:
        result_text = "Detected content:\n" + "".join(
            f" - {label}: {count} instance(s)\n" for label, count in label_counts.items()
        )
    else:
        result_text = "This image appears to be safe."

    return result_text, processed_image