import cv2
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torch
import gradio as gr

from models.common import DetectMultiBackend  # , NSFWModel
from utils.torch_utils import select_device
from utils.general import check_img_size, non_max_suppression, scale_boxes
from utils.plots import Annotator, colors
from config.settings import DETECT_MODEL_PATH

# # Load classification model
# nsfw_model = NSFWModel()

# Load YOLO model
device = select_device("")
yolo_model = DetectMultiBackend(DETECT_MODEL_PATH, device=device, dnn=False, data=None, fp16=False)
stride, names, pt = yolo_model.stride, yolo_model.names, yolo_model.pt
imgsz = check_img_size((640, 640), s=stride)


def resize_and_pad(image, target_size):
    ih, iw = image.shape[:2]
    target_h, target_w = target_size

    # Compute the scale factor that preserves the aspect ratio
    scale = min(target_h / ih, target_w / iw)

    # Compute the new size
    new_h, new_w = int(ih * scale), int(iw * scale)

    # Resize the image
    resized = cv2.resize(image, (new_w, new_h))

    # Compute the padding on each side
    pad_h = (target_h - new_h) // 2
    pad_w = (target_w - new_w) // 2

    # Add black-border padding to reach the target size
    padded = cv2.copyMakeBorder(
        resized,
        pad_h,
        target_h - new_h - pad_h,
        pad_w,
        target_w - new_w - pad_w,
        cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )

    return padded


def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
    # Image preprocessing
    im = torch.from_numpy(image).to(device).permute(2, 0, 1)
    im = im.half() if yolo_model.fp16 else im.float()
    im /= 255
    if len(im.shape) == 3:
        im = im[None]

    # Resize the tensor to the model input size (a no-op when the image was
    # already letterboxed to imgsz by resize_and_pad)
    im = torch.nn.functional.interpolate(im, size=imgsz, mode="bilinear", align_corners=False)

    # Inference
    pred = yolo_model(im, augment=False, visualize=False)
    if isinstance(pred, list):
        pred = pred[0]

    # NMS
    pred = non_max_suppression(pred, conf_threshold, iou_threshold, None, False, max_det=1000)

    # Process results
    img = image.copy()
    harmful_label_list = []
    annotations = []
    label_counts = {}

    for i, det in enumerate(pred):
        if len(det):
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()

            for *xyxy, conf, cls in reversed(det):
                c = int(cls)
                harmful_label_list.append(c)
                label_name = names[c]
                label_counts[label_name] = label_counts.get(label_name, 0) + 1

                annotation = {
                    "xyxy": xyxy,
                    "conf": conf,
                    "cls": c,
                    "label": (
                        f"{names[c]} {conf:.2f}"
                        if label_mode == "Draw Confidence"
                        else f"{names[c]}"
                    ),
                }
                annotations.append(annotation)

    # Annotate image
    annotator = Annotator(img, line_width=3, example=str(names))

    for ann in annotations:
        if label_mode == "Draw box":
            annotator.box_label(ann["xyxy"], None, color=colors(ann["cls"], True))
        elif label_mode in ["Draw Label", "Draw Confidence"]:
            annotator.box_label(ann["xyxy"], ann["label"], color=colors(ann["cls"], True))
        elif label_mode == "Censor Predictions":
            # Black out the detected region with a filled rectangle
            cv2.rectangle(
                img,
                (int(ann["xyxy"][0]), int(ann["xyxy"][1])),
                (int(ann["xyxy"][2]), int(ann["xyxy"][3])),
                (0, 0, 0),
                -1,
            )

    return annotator.result(), label_counts


def detect_nsfw(input_image, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
    if isinstance(input_image, str):  # URL input
        response = requests.get(input_image, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    else:  # File upload
        image = Image.fromarray(input_image)

    image_np = np.array(image)
    if len(image_np.shape) == 2:  # grayscale
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.shape[2] == 4:  # RGBA
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    # if detection_mode == "Simple Check":
    #     result = nsfw_model.predict(image)
    #     return result, None
    # else:  # Detailed Analysis
    #     image_np = resize_and_pad(image_np, imgsz)  # imgsz here is (640, 640)
    #     processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
    #     return "Detailed analysis completed. See the image for results.", processed_image

    # Resize image and process with YOLO
    image_np = resize_and_pad(image_np, imgsz)  # imgsz here is (640, 640)
    processed_image, label_counts = process_image_yolo(
        image_np, conf_threshold, iou_threshold, label_mode
    )

    # Construct detailed result text
    if "head" in label_counts and len(label_counts) == 1:
        head_count = label_counts["head"]
        result_text = (
            f"Detected content:\n - head: {head_count} instance(s)\nThis image appears to be safe."
        )
    elif label_counts:
        result_text = "Detected content:\n"
        for label, count in label_counts.items():
            result_text += f" - {label}: {count} instance(s)\n"
    else:
        result_text = "This image appears to be safe."

    return result_text, processed_image
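

# The gradio import above suggests this module backs a web demo, but no
# interface is built in this file. What follows is a minimal wiring sketch
# (an assumption: component names, labels, and defaults are illustrative,
# and the project's actual UI may live in another module). It maps
# detect_nsfw's four inputs and two outputs onto Gradio components.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=detect_nsfw,
        inputs=[
            gr.Image(type="numpy", label="Input image"),
            gr.Slider(0.0, 1.0, value=0.3, label="Confidence threshold"),
            gr.Slider(0.0, 1.0, value=0.45, label="IoU threshold"),
            gr.Radio(
                ["Draw box", "Draw Label", "Draw Confidence", "Censor Predictions"],
                value="Draw box",
                label="Label mode",
            ),
        ],
        outputs=[gr.Textbox(label="Detection summary"), gr.Image(label="Annotated image")],
    )
    demo.launch()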