LearningnRunning committed
Commit 83413d2 · 1 Parent(s): 1871c9c

Update data_processing.py

Files changed (1)
  1. utils/data_processing.py +21 -13
utils/data_processing.py CHANGED
@@ -76,6 +76,7 @@ def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
     img = image.copy()
     harmful_label_list = []
     annotations = []
+    label_counts = {}

     for i, det in enumerate(pred):
         if len(det):
@@ -83,8 +84,9 @@ def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):

             for *xyxy, conf, cls in reversed(det):
                 c = int(cls)
-                if c != 6:
-                    harmful_label_list.append(c)
+                harmful_label_list.append(c)
+                label_name = names[c]
+                label_counts[label_name] = label_counts.get(label_name, 0) + 1

                 annotation = {
                     "xyxy": xyxy,
@@ -98,14 +100,7 @@ def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
                 }
                 annotations.append(annotation)

-    if 4 in harmful_label_list and 10 in harmful_label_list:
-        gr.Warning("Warning: This image is featuring underwear.")
-    elif harmful_label_list:
-        gr.Error("Warning: This image may contain harmful content.")
-        img = cv2.GaussianBlur(img, (125, 125), 0)
-    else:
-        gr.Info("This image appears to be safe.")
-
+    # Annotate image
     annotator = Annotator(img, line_width=3, example=str(names))

     for ann in annotations:
@@ -122,7 +117,7 @@ def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
                 -1,
             )

-    return annotator.result()
+    return annotator.result(), label_counts


 def detect_nsfw(input_image, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
@@ -145,6 +140,19 @@ def detect_nsfw(input_image, conf_threshold=0.3, iou_threshold=0.45, label_mode=
     # image_np = resize_and_pad(image_np, imgsz)  # imgsz here is (640, 640)
     # processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
     # return "Detailed analysis completed. See the image for results.", processed_image
+
+    # Resize image and process with YOLO
     image_np = resize_and_pad(image_np, imgsz)  # imgsz here is (640, 640)
-    processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
-    return "Detailed analysis completed. See the image for results.", processed_image
+    processed_image, label_counts = process_image_yolo(
+        image_np, conf_threshold, iou_threshold, label_mode
+    )
+
+    # Construct detailed result text
+    if label_counts:
+        result_text = "Detected content:\n"
+        for label, count in label_counts.items():
+            result_text += f" - {label}: {count} instance(s)\n"
+    else:
+        result_text = "This image appears to be safe."
+
+    return result_text, processed_image
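After this change, process_image_yolo returns (annotated_image, label_counts) and detect_nsfw returns (result_text, processed_image). Below is a minimal sketch of how those two outputs might be wired into a Gradio demo; the component choices, labels, and the import path are illustrative assumptions, not part of this commit.

# Minimal usage sketch (not part of this commit): assumes detect_nsfw is
# importable from utils.data_processing and that Gradio is installed.
import gradio as gr

from utils.data_processing import detect_nsfw

demo = gr.Interface(
    fn=detect_nsfw,
    inputs=[
        gr.Image(label="Input image"),  # passed as input_image
        gr.Slider(0.0, 1.0, value=0.3, label="Confidence threshold"),
        gr.Slider(0.0, 1.0, value=0.45, label="IoU threshold"),
        gr.Radio(["Draw box"], value="Draw box", label="Label mode"),
    ],
    outputs=[
        gr.Textbox(label="Detection summary"),  # result_text with per-label counts
        gr.Image(label="Annotated image"),      # output of annotator.result()
    ],
)

if __name__ == "__main__":
    demo.launch()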