import io

import numpy as np
import torch
from fastapi import File, UploadFile, Form
from fastapi.responses import JSONResponse
from PIL import Image

# YOLOv5 modules; the import paths assume the standard YOLOv5 repo layout
# (models/ and utils/ packages at the repository root).
from models.common import DetectMultiBackend
from utils.augmentations import letterbox
from utils.general import non_max_suppression, scale_boxes
from utils.torch_utils import select_device

device = select_device('')  # '' picks CUDA if available, otherwise CPU
model = DetectMultiBackend(
    'weights/nsfw_detector_e_rok.pt',
    device=device,
    dnn=False,
    data='data/coco128.yaml',
    fp16=False,
)
names = model.names  # class index -> label name
imgsz = (640, 640)   # inference size (height, width)
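
# Optional warm-up, a hedged sketch: recent YOLOv5 releases expose
# DetectMultiBackend.warmup(), which runs one dummy forward pass so the first
# real request does not pay backend/CUDA initialisation cost. Drop this line
# if your YOLOv5 version predates warmup().
model.warmup(imgsz=(1, 3, *imgsz))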


async def process_image_api(
    file: UploadFile = File(...),
    conf_threshold: float = Form(0.25),
    iou_threshold: float = Form(0.45),
    label_mode: str = Form("Draw Confidence"),
):
    contents = await file.read()
    # Force RGB: uploads can be grayscale or RGBA, but the model expects
    # exactly three channels.
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    image_np = np.array(image)

    result = process_image(image_np, conf_threshold, iou_threshold, label_mode)
    return JSONResponse(content={"result": result.result})
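

# Wiring sketch: nothing in this file registers the handler above on an
# application, so this is an assumption of how it could be mounted when run
# standalone. The FastAPI() app object and the "/detect" path are
# illustrative, not part of the original source.
from fastapi import FastAPI

app = FastAPI()
app.post("/detect")(process_image_api)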


def process_image(image, conf_threshold, iou_threshold, label_mode):
    # Letterbox instead of a plain stretch: scale_boxes() below assumes an
    # aspect-ratio-preserving resize plus padding when it maps detections back
    # onto the original image, which a bare F.interpolate resize would break.
    im = letterbox(image, imgsz, auto=False)[0]
    im = im.transpose((2, 0, 1))  # HWC -> CHW
    im = np.ascontiguousarray(im)
    im = torch.from_numpy(im).to(device)
    im = im.half() if model.fp16 else im.float()
    im /= 255  # 0-255 -> 0.0-1.0
    if len(im.shape) == 3:
        im = im[None]  # add batch dimension

    pred = model(im, augment=False, visualize=False)
    if isinstance(pred, list):
        pred = pred[0]

    # classes=None keeps every class; class-agnostic NMS is disabled.
    pred = non_max_suppression(pred, conf_threshold, iou_threshold, None, False, max_det=1000)

    img = image.copy()

    harmful_label_list = []
    annotations = []

    for det in pred:
        if len(det):
            # Map boxes from the letterboxed inference frame back to the
            # original image resolution.
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()

            for *xyxy, conf, cls in reversed(det):
                c = int(cls)
                # Class 6 is treated as the non-harmful class by this model;
                # any other detection marks the image as NSFW.
                if c != 6:
                    harmful_label_list.append(c)

                annotations.append({
                    'xyxy': [float(v) for v in xyxy],  # plain floats, JSON-serializable
                    'conf': float(conf),
                    'cls': c,
                    'label': f"{names[c]} {conf:.2f}" if label_mode == "Draw Confidence" else names[c],
                })

    # NOTE: annotations are collected but never returned; ProcessResponse only
    # carries the overall verdict.
    result = 'nsfw' if harmful_label_list else 'normal'
    return ProcessResponse(result=result)


class ProcessResponse:
    def __init__(self, result: str):  # the verdict string: 'nsfw' or 'normal'
        self.result = result
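

# Example request against the wiring sketch above (host, port, and path are
# assumptions):
#   curl -F "file=@sample.jpg" -F "conf_threshold=0.25" http://localhost:8000/detect
if __name__ == "__main__":
    # Run sketch, assuming uvicorn is installed and the app object from the
    # wiring sketch above is used.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)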