LearningnRunning committed
Commit 9e31b92 · 1 Parent(s): a8f267e

FEAT Simple version

Files changed (2):
  1. app.py +2 -3
  2. utils/data_processing.py +126 -0
app.py CHANGED
@@ -18,11 +18,10 @@ from models.common import DetectMultiBackend
 from utils.general import (check_img_size, non_max_suppression, scale_boxes)
 from utils.plots import Annotator, colors
 from utils.torch_utils import select_device
-from config.settings import MODEL_PATH

 # Load the YOLOv9 model
 device = select_device('')
-model = DetectMultiBackend(MODEL_PATH, device=device, dnn=False, data=None, fp16=False)
+model = DetectMultiBackend('./weights/nsfw_detector_e_rok.pt', device=device, dnn=False, data=None, fp16=False)
 stride, names, pt = model.stride, model.names, model.pt
 imgsz = check_img_size((640, 640), s=stride)  # check image size

@@ -114,7 +113,7 @@ demo = gr.Interface(
     fn=detect_nsfw,
     inputs=[
         gr.Image(type="numpy", label="Upload an image or enter a URL"),
-        gr.Slider(0, 1, value=0.45, label="Confidence Threshold"),
+        gr.Slider(0, 1, value=0.3, label="Confidence Threshold"),
         gr.Slider(0, 1, value=0.45, label="Overlap Threshold"),
         gr.Dropdown(["Draw box", "Draw Label", "Draw Confidence", "Censor Predictions"], label="Label Display Mode", value="Draw box")
     ],
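The commit replaces the config.settings.MODEL_PATH indirection with a hard-coded local weights path. A minimal sketch (not part of the commit; the load_detector helper name is illustrative) of guarding that load so a missing weights file fails with a clear message instead of an opaque loader error:

from pathlib import Path

from models.common import DetectMultiBackend
from utils.torch_utils import select_device

WEIGHTS = Path('./weights/nsfw_detector_e_rok.pt')  # path introduced by this commit

def load_detector(weights=WEIGHTS):
    # Fail early if the weights were not added to the Space checkout.
    if not weights.is_file():
        raise FileNotFoundError(f"Expected YOLOv9 weights at {weights.resolve()}")
    device = select_device('')  # '' picks CUDA if available, otherwise CPU
    model = DetectMultiBackend(str(weights), device=device, dnn=False, data=None, fp16=False)
    return model, device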
utils/data_processing.py ADDED
@@ -0,0 +1,126 @@
+import cv2
+import numpy as np
+from PIL import Image
+import requests
+from io import BytesIO
+import torch
+import gradio as gr
+
+from models.common import DetectMultiBackend, NSFWModel
+from utils.torch_utils import select_device
+from utils.general import (check_img_size, non_max_suppression, scale_boxes)
+from utils.plots import Annotator, colors
+
+
+# Load classification model
+nsfw_model = NSFWModel()
+
+# Load YOLO model
+device = select_device('')
+yolo_model = DetectMultiBackend('./weights/nsfw_detector_e_rok.pt', device=device, dnn=False, data=None, fp16=False)
+stride, names, pt = yolo_model.stride, yolo_model.names, yolo_model.pt
+imgsz = check_img_size((640, 640), s=stride)
+
+def resize_and_pad(image, target_size):
+    ih, iw = image.shape[:2]
+    target_h, target_w = target_size
+
+    # Scale factor that preserves the image's aspect ratio
+    scale = min(target_h / ih, target_w / iw)
+
+    # Compute the new size
+    new_h, new_w = int(ih * scale), int(iw * scale)
+
+    # Resize the image
+    resized = cv2.resize(image, (new_w, new_h))
+
+    # Compute the padding
+    pad_h = (target_h - new_h) // 2
+    pad_w = (target_w - new_w) // 2
+
+    # Add the padding
+    padded = cv2.copyMakeBorder(resized, pad_h, target_h - new_h - pad_h, pad_w, target_w - new_w - pad_w, cv2.BORDER_CONSTANT, value=[0, 0, 0])
+
+    return padded
+
+def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
+    # Image preprocessing
+    im = torch.from_numpy(image).to(device).permute(2, 0, 1)
+    im = im.half() if yolo_model.fp16 else im.float()
+    im /= 255
+    if len(im.shape) == 3:
+        im = im[None]
+
+    # Resize image
+    im = torch.nn.functional.interpolate(im, size=imgsz, mode='bilinear', align_corners=False)
+
+    # Inference
+    pred = yolo_model(im, augment=False, visualize=False)
+    if isinstance(pred, list):
+        pred = pred[0]
+
+    # NMS
+    pred = non_max_suppression(pred, conf_threshold, iou_threshold, None, False, max_det=1000)
+
+    # Process results
+    img = image.copy()
+    harmful_label_list = []
+    annotations = []
+
+    for i, det in enumerate(pred):
+        if len(det):
+            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()
+
+            for *xyxy, conf, cls in reversed(det):
+                c = int(cls)
+                if c != 6:
+                    harmful_label_list.append(c)
+
+                annotation = {
+                    'xyxy': xyxy,
+                    'conf': conf,
+                    'cls': c,
+                    'label': f"{names[c]} {conf:.2f}" if label_mode == "Draw Confidence" else f"{names[c]}"
+                }
+                annotations.append(annotation)
+
+    if 4 in harmful_label_list and 10 in harmful_label_list:
+        gr.Warning("Warning: This image is featuring underwear.")
+    elif harmful_label_list:
+        gr.Error("Warning: This image may contain harmful content.")
+        img = cv2.GaussianBlur(img, (125, 125), 0)
+    else:
+        gr.Info('This image appears to be safe.')
+
+    annotator = Annotator(img, line_width=3, example=str(names))
+
+    for ann in annotations:
+        if label_mode == "Draw box":
+            annotator.box_label(ann['xyxy'], None, color=colors(ann['cls'], True))
+        elif label_mode in ["Draw Label", "Draw Confidence"]:
+            annotator.box_label(ann['xyxy'], ann['label'], color=colors(ann['cls'], True))
+        elif label_mode == "Censor Predictions":
+            cv2.rectangle(img, (int(ann['xyxy'][0]), int(ann['xyxy'][1])), (int(ann['xyxy'][2]), int(ann['xyxy'][3])), (0, 0, 0), -1)
+
+    return annotator.result()
+
+def detect_nsfw(input_image, detection_mode, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
+    if isinstance(input_image, str):  # URL input
+        response = requests.get(input_image)
+        image = Image.open(BytesIO(response.content))
+    else:  # File upload
+        image = Image.fromarray(input_image)
+
+    image_np = np.array(image)
+    if len(image_np.shape) == 2:  # grayscale
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
+    elif image_np.shape[2] == 4:  # RGBA
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
+
+    if detection_mode == "Simple Check":
+        result = nsfw_model.predict(image)
+        return result, None
+    else:  # Detailed Analysis
+        image_np = resize_and_pad(image_np, imgsz)  # imgsz here is (640, 640)
+        processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
+        return "Detailed analysis completed. See the image for results.", processed_image