LearningnRunning committed on
Commit 82654de · 1 Parent(s): 1de3353

feat Only detail detect model

Files changed (4)
  1. app.py +32 -25
  2. config/settings.py +3 -2
  3. models/common.py +14 -14
  4. utils/data_processing.py +64 -40
app.py CHANGED
@@ -1,9 +1,10 @@
-import gradio as gr
+import os
 import sys
 from pathlib import Path
-import os
+import gradio as gr
 from utils.data_processing import detect_nsfw
-# Import YOLO-related modules
+
+# YOLO-related module path setup
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[0]
 if str(ROOT) not in sys.path:
@@ -11,37 +12,43 @@ if str(ROOT) not in sys.path:
 ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
 
 
-
-
-
-# Gradio interface
+# Define the Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# NSFW Content Detection")
-    with gr.Row():
-        detection_mode = gr.Radio(["Simple Check", "Detailed Analysis"], label="Detection Mode", value="Simple Check")
+    gr.Markdown("# NSFW Content Detection - Detailed Analysis")
+
+    # Advanced parameters for Detailed Analysis
     with gr.Row():
-        conf_threshold = gr.Slider(0, 1, value=0.3, label="Confidence Threshold", visible=False)
-        iou_threshold = gr.Slider(0, 1, value=0.45, label="Overlap Threshold", visible=False)
-        label_mode = gr.Dropdown(["Draw box", "Draw Label", "Draw Confidence", "Censor Predictions"], label="Label Display Mode", value="Draw box", visible=False)
-
+        conf_threshold = gr.Slider(0, 1, value=0.3, label="Confidence Threshold")
+        iou_threshold = gr.Slider(0, 1, value=0.45, label="Overlap Threshold")
+        label_mode = gr.Dropdown(
+            ["Draw box", "Draw Label", "Draw Confidence", "Censor Predictions"],
+            label="Label Display Mode",
+            value="Draw box",
+        )
+
+    # Input and output components
     with gr.Row():
         input_image = gr.Image(type="numpy", label="Upload an image or enter a URL")
        output_text = gr.Textbox(label="Detection Result")
     with gr.Row():
-        output_image = gr.Image(type="numpy", label="Processed Image (for detailed analysis)", visible=False)
+        output_image = gr.Image(type="numpy", label="Processed Image (for detailed analysis)")
 
+    # Detection button
     detect_button = gr.Button("Detect")
-
-    def update_visibility(mode):
-        return [gr.update(visible=(mode == "Detailed Analysis"))] * 4
-
-    detection_mode.change(update_visibility, inputs=[detection_mode], outputs=[conf_threshold, iou_threshold, label_mode, output_image])
-
+
+    # Connect detection button to the detect_nsfw function
+    def safe_detect_nsfw(image, conf, iou, label):
+        try:
+            return detect_nsfw(image, conf, iou, label)
+        except Exception as e:
+            return f"Error during detection: {e}", None
+
     detect_button.click(
-        detect_nsfw,
-        inputs=[input_image, detection_mode, conf_threshold, iou_threshold, label_mode],
-        outputs=[output_text, output_image]
+        safe_detect_nsfw,
+        inputs=[input_image, conf_threshold, iou_threshold, label_mode],
+        outputs=[output_text, output_image],
     )
 
+# Launch the Gradio app
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0")
+    demo.launch(server_name="0.0.0.0")
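The rewired click handler follows a common Gradio pattern: wrap the worker function so an exception becomes an error string in the output Textbox instead of an unhandled crash. A minimal self-contained sketch of the same pattern (`analyze` is a hypothetical stand-in for `detect_nsfw`):

```python
import gradio as gr

def analyze(image, conf):
    # Hypothetical stand-in for detect_nsfw(image, conf, iou, label).
    if image is None:
        raise ValueError("no image provided")
    return f"ok (conf={conf})", image

def safe_analyze(image, conf):
    # Mirrors safe_detect_nsfw: on failure, return (error text, None) so both
    # outputs (Textbox, Image) still receive a value and the UI stays responsive.
    try:
        return analyze(image, conf)
    except Exception as e:
        return f"Error during detection: {e}", None

with gr.Blocks() as demo:
    img = gr.Image(type="numpy")
    conf = gr.Slider(0, 1, value=0.3, label="Confidence Threshold")
    txt = gr.Textbox(label="Detection Result")
    out = gr.Image(type="numpy")
    gr.Button("Detect").click(safe_analyze, inputs=[img, conf], outputs=[txt, out])

if __name__ == "__main__":
    demo.launch()
```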
config/settings.py CHANGED
@@ -1,7 +1,8 @@
 import os
+
 # Get the project root directory path
 BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
 # Set MODEL_PATH to an absolute path
-DETECT_MODEL_PATH = os.path.join(BASE_DIR, 'weights', 'yolov9_c_nsfw.pt')
-CLASSIFICATION_MODEL_PATH = "Falconsai/nsfw_image_detection"
+DETECT_MODEL_PATH = os.path.join(BASE_DIR, "weights", "yolov9_c_nsfw.pt")
+# CLASSIFICATION_MODEL_PATH = "Falconsai/nsfw_image_detection"
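Since settings.py lives in config/, the two nested dirname() calls walk up from config/settings.py to the repository root, so DETECT_MODEL_PATH resolves to <repo>/weights/yolov9_c_nsfw.pt regardless of the current working directory. A quick sketch of the resolution (the /repo path is hypothetical):

```python
import os

settings_file = "/repo/config/settings.py"  # hypothetical install location
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(settings_file)))
print(base_dir)                                                # /repo
print(os.path.join(base_dir, "weights", "yolov9_c_nsfw.pt"))   # /repo/weights/yolov9_c_nsfw.pt
```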
models/common.py CHANGED
@@ -30,7 +30,7 @@ from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suff
                            xywh2xyxy, xyxy2xywh, yaml_load)
 from utils.plots import Annotator, colors, save_one_box
 from utils.torch_utils import copy_attr, smart_inference_mode
-from config.settings import CLASSIFICATION_MODEL_PATH
+# from config.settings import CLASSIFICATION_MODEL_PATH
 
 def autopad(k, p=None, d=1):  # kernel, padding, dilation
     # Pad to 'same' shape outputs
@@ -1200,17 +1200,17 @@ class Classify(nn.Module):
         x = torch.cat(x, 1)
         return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
 
-class NSFWModel:
-    def __init__(self):
-        self.model = AutoModelForImageClassification.from_pretrained(CLASSIFICATION_MODEL_PATH)
-        self.processor = ViTImageProcessor.from_pretrained(CLASSIFICATION_MODEL_PATH)
-        self.id2label = self.model.config.id2label
-
-    def predict(self, image: Image.Image) -> str:
-        with torch.no_grad():
-            inputs = self.processor(images=image, return_tensors="pt")
-            outputs = self.model(**inputs)
-            logits = outputs.logits
-            predicted_label = logits.argmax(-1).item()
+# class NSFWModel:
+#     def __init__(self):
+#         self.model = AutoModelForImageClassification.from_pretrained(CLASSIFICATION_MODEL_PATH, local_files_only=True)
+#         self.processor = ViTImageProcessor.from_pretrained(CLASSIFICATION_MODEL_PATH, local_files_only=True)
+#         self.id2label = self.model.config.id2label
+
+#     def predict(self, image: Image.Image) -> str:
+#         with torch.no_grad():
+#             inputs = self.processor(images=image, return_tensors="pt")
+#             outputs = self.model(**inputs)
+#             logits = outputs.logits
+#             predicted_label = logits.argmax(-1).item()
 
-        return self.id2label[predicted_label]
+#         return self.id2label[predicted_label]
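For reference, the now-commented NSFWModel wraps a standard Hugging Face image-classification flow. If the simple check were ever restored, the usage implied by the commented code would look roughly like this (a sketch assuming the transformers package and access to the Falconsai/nsfw_image_detection checkpoint; local_files_only is omitted here, so weights come from the Hub):

```python
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

MODEL_ID = "Falconsai/nsfw_image_detection"  # the former CLASSIFICATION_MODEL_PATH

model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
processor = ViTImageProcessor.from_pretrained(MODEL_ID)

def predict(image: Image.Image) -> str:
    # Same flow as the commented-out NSFWModel.predict:
    # preprocess -> forward pass -> argmax over logits -> human-readable label.
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")
        logits = model(**inputs).logits
    return model.config.id2label[logits.argmax(-1).item()]

# Example (hypothetical input file): print(predict(Image.open("sample.jpg")))
```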
utils/data_processing.py CHANGED
@@ -6,43 +6,53 @@ from io import BytesIO
 import torch
 import gradio as gr
 
-from models.common import DetectMultiBackend, NSFWModel
+from models.common import DetectMultiBackend  # , NSFWModel
 from utils.torch_utils import select_device
-from utils.general import (check_img_size, non_max_suppression, scale_boxes)
+from utils.general import check_img_size, non_max_suppression, scale_boxes
 from utils.plots import Annotator, colors
 from config.settings import DETECT_MODEL_PATH
 
-# Load classification model
-nsfw_model = NSFWModel()
+# # Load classification model
+# nsfw_model = NSFWModel()
 
 # Load YOLO model
-device = select_device('')
+device = select_device("")
 yolo_model = DetectMultiBackend(DETECT_MODEL_PATH, device=device, dnn=False, data=None, fp16=False)
 stride, names, pt = yolo_model.stride, yolo_model.names, yolo_model.pt
 imgsz = check_img_size((640, 640), s=stride)
 
+
 def resize_and_pad(image, target_size):
     ih, iw = image.shape[:2]
     target_h, target_w = target_size
 
     # Calculate the aspect ratio of the image
-    scale = min(target_h/ih, target_w/iw)
-
+    scale = min(target_h / ih, target_w / iw)
+
     # Calculate the new size
     new_h, new_w = int(ih * scale), int(iw * scale)
-
+
     # Resize the image
     resized = cv2.resize(image, (new_w, new_h))
-
+
     # Calculate padding
     pad_h = (target_h - new_h) // 2
     pad_w = (target_w - new_w) // 2
-
+
     # Add padding
-    padded = cv2.copyMakeBorder(resized, pad_h, target_h-new_h-pad_h, pad_w, target_w-new_w-pad_w, cv2.BORDER_CONSTANT, value=[0,0,0])
-
+    padded = cv2.copyMakeBorder(
+        resized,
+        pad_h,
+        target_h - new_h - pad_h,
+        pad_w,
+        target_w - new_w - pad_w,
+        cv2.BORDER_CONSTANT,
+        value=[0, 0, 0],
+    )
+
     return padded
 
+
 def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
     # Image preprocessing
     im = torch.from_numpy(image).to(device).permute(2, 0, 1)
@@ -50,15 +60,15 @@ def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
     im /= 255
     if len(im.shape) == 3:
         im = im[None]
-
+
     # Resize image
-    im = torch.nn.functional.interpolate(im, size=imgsz, mode='bilinear', align_corners=False)
-
+    im = torch.nn.functional.interpolate(im, size=imgsz, mode="bilinear", align_corners=False)
+
     # Inference
     pred = yolo_model(im, augment=False, visualize=False)
     if isinstance(pred, list):
         pred = pred[0]
-
+
     # NMS
     pred = non_max_suppression(pred, conf_threshold, iou_threshold, None, False, max_det=1000)
 
@@ -66,61 +76,75 @@ def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
     img = image.copy()
     harmful_label_list = []
     annotations = []
-
+
     for i, det in enumerate(pred):
         if len(det):
             det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()
-
+
             for *xyxy, conf, cls in reversed(det):
                 c = int(cls)
                 if c != 6:
                     harmful_label_list.append(c)
-
+
                 annotation = {
-                    'xyxy': xyxy,
-                    'conf': conf,
-                    'cls': c,
-                    'label': f"{names[c]} {conf:.2f}" if label_mode == "Draw Confidence" else f"{names[c]}"
+                    "xyxy": xyxy,
+                    "conf": conf,
+                    "cls": c,
+                    "label": (
+                        f"{names[c]} {conf:.2f}"
+                        if label_mode == "Draw Confidence"
+                        else f"{names[c]}"
+                    ),
                 }
                 annotations.append(annotation)
-
+
     if 4 in harmful_label_list and 10 in harmful_label_list:
         gr.Warning("Warning: This image is featuring underwear.")
     elif harmful_label_list:
         gr.Error("Warning: This image may contain harmful content.")
         img = cv2.GaussianBlur(img, (125, 125), 0)
     else:
-        gr.Info('This image appears to be safe.')
-
+        gr.Info("This image appears to be safe.")
+
     annotator = Annotator(img, line_width=3, example=str(names))
-
+
     for ann in annotations:
         if label_mode == "Draw box":
-            annotator.box_label(ann['xyxy'], None, color=colors(ann['cls'], True))
+            annotator.box_label(ann["xyxy"], None, color=colors(ann["cls"], True))
         elif label_mode in ["Draw Label", "Draw Confidence"]:
-            annotator.box_label(ann['xyxy'], ann['label'], color=colors(ann['cls'], True))
+            annotator.box_label(ann["xyxy"], ann["label"], color=colors(ann["cls"], True))
         elif label_mode == "Censor Predictions":
-            cv2.rectangle(img, (int(ann['xyxy'][0]), int(ann['xyxy'][1])), (int(ann['xyxy'][2]), int(ann['xyxy'][3])), (0, 0, 0), -1)
+            cv2.rectangle(
+                img,
+                (int(ann["xyxy"][0]), int(ann["xyxy"][1])),
+                (int(ann["xyxy"][2]), int(ann["xyxy"][3])),
+                (0, 0, 0),
+                -1,
+            )
 
     return annotator.result()
 
-def detect_nsfw(input_image, detection_mode, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
+
+def detect_nsfw(input_image, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
     if isinstance(input_image, str):  # URL input
         response = requests.get(input_image)
         image = Image.open(BytesIO(response.content))
     else:  # File upload
         image = Image.fromarray(input_image)
-
+
     image_np = np.array(image)
     if len(image_np.shape) == 2:  # grayscale
         image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
     elif image_np.shape[2] == 4:  # RGBA
         image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
-
-    if detection_mode == "Simple Check":
-        result = nsfw_model.predict(image)
-        return result, None
-    else:  # Detailed Analysis
-        image_np = resize_and_pad(image_np, imgsz)  # here, imgsz is (640, 640)
-        processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
-        return "Detailed analysis completed. See the image for results.", processed_image
+
+    # if detection_mode == "Simple Check":
+    #     result = nsfw_model.predict(image)
+    #     return result, None
+    # else:  # Detailed Analysis
+    #     image_np = resize_and_pad(image_np, imgsz)  # here, imgsz is (640, 640)
+    #     processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
+    #     return "Detailed analysis completed. See the image for results.", processed_image
+    image_np = resize_and_pad(image_np, imgsz)  # here, imgsz is (640, 640)
+    processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
+    return "Detailed analysis completed. See the image for results.", processed_image