File size: 5,446 Bytes
9e31b92 82654de 9e31b92 82654de 9e31b92 1de3353 9e31b92 82654de 9e31b92 82654de 3a53eae 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 9e31b92 83413d2 82654de 9e31b92 82654de 9e31b92 83413d2 82654de 9e31b92 82654de 9e31b92 82654de 83413d2 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 9e31b92 83413d2 9e31b92 82654de 9e31b92 82654de 9e31b92 82654de 83413d2 82654de 83413d2 b3673c2 043f3b8 83413d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import cv2
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torch
import gradio as gr
from models.common import DetectMultiBackend # , NSFWModel
from utils.torch_utils import select_device
from utils.general import check_img_size, non_max_suppression, scale_boxes
from utils.plots import Annotator, colors
from config.settings import DETECT_MODEL_PATH
# NOTE: the separate NSFW classification model is currently disabled.
# nsfw_model = NSFWModel()

# Build the YOLO detector once at import time; the device is whatever
# select_device picks for an empty request string.
device = select_device("")
yolo_model = DetectMultiBackend(
    DETECT_MODEL_PATH,
    device=device,
    dnn=False,
    data=None,
    fp16=False,
)

# Model metadata reused by the functions below.
stride = yolo_model.stride
names = yolo_model.names
pt = yolo_model.pt

# Network input size, passed through check_img_size together with the stride.
imgsz = check_img_size((640, 640), s=stride)
def resize_and_pad(image, target_size):
    """Letterbox ``image`` into ``target_size`` while keeping its aspect ratio.

    The image is scaled by the largest single factor that still fits inside
    the target, then placed on a black canvas of exactly ``target_size``.

    Args:
        image: Array whose first two dimensions are (height, width); it is
            passed to ``cv2.resize`` as-is.
        target_size: ``(height, width)`` pair describing the output canvas.

    Returns:
        The resized image padded with black borders to ``target_size``.
    """
    ih, iw = image.shape[:2]
    target_h, target_w = target_size

    # Use one scale factor for both axes so the aspect ratio is preserved.
    scale = min(target_h / ih, target_w / iw)

    # Clamp to one pixel: for very elongated inputs the short side would
    # otherwise truncate to 0, and cv2.resize rejects a zero-sized target.
    new_h = max(1, int(ih * scale))
    new_w = max(1, int(iw * scale))

    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(image, (new_w, new_h))

    # Split the leftover space so the image is centred; when the leftover is
    # odd, the extra pixel goes to the bottom / right border.
    pad_top = (target_h - new_h) // 2
    pad_bottom = target_h - new_h - pad_top
    pad_left = (target_w - new_w) // 2
    pad_right = target_w - new_w - pad_left

    return cv2.copyMakeBorder(
        resized,
        pad_top,
        pad_bottom,
        pad_left,
        pad_right,
        cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )
def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
    """Run the YOLO detector on ``image`` and draw the detections on a copy.

    Args:
        image: ``numpy`` array in height x width x channel layout (it is
            permuted to channel-first before inference).
        conf_threshold: Confidence threshold forwarded to
            ``non_max_suppression``.
        iou_threshold: IoU threshold forwarded to ``non_max_suppression``.
        label_mode: ``"Draw box"``, ``"Draw Label"``, ``"Draw Confidence"`` or
            ``"Censor Predictions"``. Any other value leaves the copy
            unannotated.

    Returns:
        A ``(annotated_image, label_counts)`` tuple. ``annotated_image`` is
        ``Annotator.result()`` for the copy of ``image``; ``label_counts``
        maps each detected class name to its number of detections.
    """
    # Inference only: without no_grad the forward pass would record an
    # autograd graph that is never used, costing memory and time.
    with torch.no_grad():
        # Preprocess: channel-first, model precision, scaled by 1/255, batched.
        im = torch.from_numpy(image).to(device).permute(2, 0, 1)
        im = im.half() if yolo_model.fp16 else im.float()
        im /= 255
        if len(im.shape) == 3:
            im = im[None]

        # Resize to the network input size.
        im = torch.nn.functional.interpolate(
            im, size=imgsz, mode="bilinear", align_corners=False
        )

        pred = yolo_model(im, augment=False, visualize=False)
        if isinstance(pred, list):
            pred = pred[0]

        pred = non_max_suppression(
            pred, conf_threshold, iou_threshold, None, False, max_det=1000
        )

    # Collect detections, with boxes mapped back to the input image size.
    img = image.copy()
    annotations = []
    label_counts = {}
    show_conf = label_mode == "Draw Confidence"
    for det in pred:
        if not len(det):
            continue
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)
            label_name = names[c]
            label_counts[label_name] = label_counts.get(label_name, 0) + 1
            annotations.append(
                {
                    "xyxy": xyxy,
                    "conf": conf,
                    "cls": c,
                    "label": f"{label_name} {conf:.2f}" if show_conf else f"{label_name}",
                }
            )

    if label_mode == "Censor Predictions":
        # Fill the boxes before the Annotator is created so the censoring does
        # not depend on the Annotator sharing this buffer.
        # NOTE(review): Annotator is project code not visible here -- confirm.
        for ann in annotations:
            x1, y1, x2, y2 = (int(v) for v in ann["xyxy"])
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 0), -1)

    annotator = Annotator(img, line_width=3, example=str(names))
    if label_mode == "Draw box":
        for ann in annotations:
            annotator.box_label(ann["xyxy"], None, color=colors(ann["cls"], True))
    elif label_mode in ("Draw Label", "Draw Confidence"):
        for ann in annotations:
            annotator.box_label(ann["xyxy"], ann["label"], color=colors(ann["cls"], True))

    return annotator.result(), label_counts
def detect_nsfw(input_image, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
    """Detect content in an image and summarise what was found.

    Args:
        input_image: Either a URL string to download the image from, or an
            array accepted by ``PIL.Image.fromarray``.
        conf_threshold: Confidence threshold passed to ``process_image_yolo``.
        iou_threshold: IoU threshold passed to ``process_image_yolo``.
        label_mode: Annotation mode passed to ``process_image_yolo``.

    Returns:
        A ``(result_text, processed_image)`` tuple: a text summary of the
        per-label detection counts and the annotated image.

    Raises:
        ValueError: If ``input_image`` is ``None``.
        requests.RequestException: If the URL cannot be fetched, times out,
            or answers with an HTTP error status.
    """
    if input_image is None:
        raise ValueError("No input image was provided.")

    if isinstance(input_image, str):  # URL input
        # NOTE(review): the URL is fetched as given; if this is exposed to
        # untrusted users, consider restricting the allowed hosts/schemes.
        # The timeout keeps a stalled server from hanging the request, and
        # raise_for_status stops an error page being parsed as an image.
        response = requests.get(input_image, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    else:  # File upload
        image = Image.fromarray(input_image)

    # Let PIL normalise every mode (grayscale, palette, LA, RGBA, CMYK, ...)
    # to 3-channel RGB instead of special-casing array shapes.
    image_np = np.array(image.convert("RGB"))

    # NOTE: a "Simple Check" path using the NSFW classification model used to
    # live here; only the YOLO-based detailed analysis is active.

    # Letterbox to the network input size (imgsz is (640, 640) here).
    image_np = resize_and_pad(image_np, imgsz)
    processed_image, label_counts = process_image_yolo(
        image_np, conf_threshold, iou_threshold, label_mode
    )

    # Build the text summary.
    if len(label_counts) == 1 and "head" in label_counts:
        # Only heads were detected: report them and call the image safe.
        head_count = label_counts["head"]
        result_text = (
            f"Detected content:\n - head: {head_count} instance(s)\nThis image appears to be safe."
        )
    elif label_counts:
        lines = [
            f" - {label}: {count} instance(s)\n" for label, count in label_counts.items()
        ]
        result_text = "Detected content:\n" + "".join(lines)
    else:
        result_text = "This image appears to be safe."

    return result_text, processed_image
|