# adult_image_detector/utils/data_processing.py
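"""Data-processing utilities for the adult image detector.

Provides an aspect-preserving resize/pad helper, a YOLO-based NSFW region
detector that annotates, censors, or blurs detections, and a simple
classifier-based check.
"""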
import cv2
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torch
import gradio as gr
from models.common import DetectMultiBackend, NSFWModel
from utils.torch_utils import select_device
from utils.general import (check_img_size, non_max_suppression, scale_boxes)
from utils.plots import Annotator, colors

# Load the classification model (used for the "Simple Check" mode)
nsfw_model = NSFWModel()

# Load the YOLO detection model (used for the "Detailed Analysis" mode)
device = select_device('')
yolo_model = DetectMultiBackend('./weights/nsfw_detector_e_rok.pt', device=device, dnn=False, data=None, fp16=False)
stride, names, pt = yolo_model.stride, yolo_model.names, yolo_model.pt
imgsz = check_img_size((640, 640), s=stride)

def resize_and_pad(image, target_size):
    ih, iw = image.shape[:2]
    target_h, target_w = target_size

    # Scale factor that preserves the aspect ratio
    scale = min(target_h / ih, target_w / iw)

    # New size after scaling
    new_h, new_w = int(ih * scale), int(iw * scale)

    # Resize the image
    resized = cv2.resize(image, (new_w, new_h))

    # Padding needed to reach the target size
    pad_h = (target_h - new_h) // 2
    pad_w = (target_w - new_w) // 2

    # Pad with black borders so the output is exactly target_size
    padded = cv2.copyMakeBorder(resized, pad_h, target_h - new_h - pad_h, pad_w, target_w - new_w - pad_w,
                                cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return padded

def process_image_yolo(image, conf_threshold, iou_threshold, label_mode):
    # Image preprocessing: HWC uint8 -> CHW float tensor in [0, 1]
    im = torch.from_numpy(image).to(device).permute(2, 0, 1)
    im = im.half() if yolo_model.fp16 else im.float()
    im /= 255
    if len(im.shape) == 3:
        im = im[None]  # add batch dimension

    # Resize image to the model input size
    im = torch.nn.functional.interpolate(im, size=imgsz, mode='bilinear', align_corners=False)

    # Inference
    pred = yolo_model(im, augment=False, visualize=False)
    if isinstance(pred, list):
        pred = pred[0]

    # NMS
    pred = non_max_suppression(pred, conf_threshold, iou_threshold, None, False, max_det=1000)

    # Process results
    img = image.copy()
    harmful_label_list = []
    annotations = []
    for i, det in enumerate(pred):
        if len(det):
            # Rescale boxes from model input size to the original image size
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)
                if c != 6:  # class index 6 is not counted as harmful
                    harmful_label_list.append(c)
                annotation = {
                    'xyxy': xyxy,
                    'conf': conf,
                    'cls': c,
                    'label': f"{names[c]} {conf:.2f}" if label_mode == "Draw Confidence" else f"{names[c]}"
                }
                annotations.append(annotation)

    if 4 in harmful_label_list and 10 in harmful_label_list:
        # Class indices 4 and 10 together are treated as underwear
        gr.Warning("Warning: This image appears to feature underwear.")
    elif harmful_label_list:
        # gr.Error is an exception and only shows when raised, which would abort
        # processing here, so a warning toast is shown and the image is blurred instead.
        gr.Warning("Warning: This image may contain harmful content.")
        img = cv2.GaussianBlur(img, (125, 125), 0)
    else:
        gr.Info('This image appears to be safe.')

    # Draw the requested annotation style on the (possibly blurred) image
    annotator = Annotator(img, line_width=3, example=str(names))
    for ann in annotations:
        if label_mode == "Draw box":
            annotator.box_label(ann['xyxy'], None, color=colors(ann['cls'], True))
        elif label_mode in ["Draw Label", "Draw Confidence"]:
            annotator.box_label(ann['xyxy'], ann['label'], color=colors(ann['cls'], True))
        elif label_mode == "Censor Predictions":
            cv2.rectangle(img, (int(ann['xyxy'][0]), int(ann['xyxy'][1])),
                          (int(ann['xyxy'][2]), int(ann['xyxy'][3])), (0, 0, 0), -1)

    return annotator.result()

def detect_nsfw(input_image, detection_mode, conf_threshold=0.3, iou_threshold=0.45, label_mode="Draw box"):
    if isinstance(input_image, str):  # URL input
        response = requests.get(input_image)
        image = Image.open(BytesIO(response.content))
    else:  # File upload
        image = Image.fromarray(input_image)

    image_np = np.array(image)

    # Normalize the channel layout to 3-channel RGB
    if len(image_np.shape) == 2:  # grayscale
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.shape[2] == 4:  # RGBA
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    if detection_mode == "Simple Check":
        result = nsfw_model.predict(image)
        return result, None
    else:  # Detailed Analysis
        image_np = resize_and_pad(image_np, imgsz)  # imgsz here is (640, 640)
        processed_image = process_image_yolo(image_np, conf_threshold, iou_threshold, label_mode)
        return "Detailed analysis completed. See the image for results.", processed_image