# largely borrowed from https://github.dev/elliottzheng/batch-face/face_detection/alignment.py from typing import Dict, List, Optional, Tuple import torch import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.functional as F import torchvision.models._utils as _utils from .base import FaceDetector from itertools import product as product from math import ceil pretrained_urls = { "mobilenet": "https://github.com/elliottzheng/face-detection/releases/download/0.0.1/mobilenet0.25_Final.pth", "resnet50": "https://github.com/elliottzheng/face-detection/releases/download/0.0.1/Resnet50_Final.pth" } def conv_bn(inp, oup, stride=1, leaky=0): return nn.Sequential( nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True), ) def conv_bn_no_relu(inp, oup, stride): return nn.Sequential( nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), ) def conv_bn1X1(inp, oup, stride, leaky=0): return nn.Sequential( nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True), ) def conv_dw(inp, oup, stride, leaky=0.1): return nn.Sequential( nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), nn.BatchNorm2d(inp), nn.LeakyReLU(negative_slope=leaky, inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True), ) class SSH(nn.Module): def __init__(self, in_channel, out_channel): super(SSH, self).__init__() assert out_channel % 4 == 0 leaky = 0 if out_channel <= 64: leaky = 0.1 self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) self.conv5X5_1 = conv_bn( in_channel, out_channel // 4, stride=1, leaky=leaky) self.conv5X5_2 = conv_bn_no_relu( out_channel // 4, out_channel // 4, stride=1) self.conv7X7_2 = conv_bn( out_channel // 4, out_channel // 4, stride=1, leaky=leaky ) self.conv7x7_3 = conv_bn_no_relu( out_channel // 4, out_channel // 4, stride=1) def forward(self, input): conv3X3 = self.conv3X3(input) conv5X5_1 = self.conv5X5_1(input) conv5X5 = self.conv5X5_2(conv5X5_1) conv7X7_2 = self.conv7X7_2(conv5X5_1) conv7X7 = self.conv7x7_3(conv7X7_2) out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) out = F.relu(out) return out class FPN(nn.Module): def __init__(self, in_channels_list, out_channels): super(FPN, self).__init__() leaky = 0 if out_channels <= 64: leaky = 0.1 self.output1 = conv_bn1X1( in_channels_list[0], out_channels, stride=1, leaky=leaky ) self.output2 = conv_bn1X1( in_channels_list[1], out_channels, stride=1, leaky=leaky ) self.output3 = conv_bn1X1( in_channels_list[2], out_channels, stride=1, leaky=leaky ) self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) def forward(self, input): # names = list(input.keys()) input = list(input.values()) output1 = self.output1(input[0]) output2 = self.output2(input[1]) output3 = self.output3(input[2]) up3 = F.interpolate( output3, size=[output2.size(2), output2.size(3)], mode="nearest" ) output2 = output2 + up3 output2 = self.merge2(output2) up2 = F.interpolate( output2, size=[output1.size(2), output1.size(3)], mode="nearest" ) output1 = output1 + up2 output1 = self.merge1(output1) out = [output1, output2, output3] return out class MobileNetV1(nn.Module): def __init__(self): super(MobileNetV1, self).__init__() self.stage1 = nn.Sequential( conv_bn(3, 8, 2, leaky=0.1), # 3 conv_dw(8, 16, 1), # 7 conv_dw(16, 32, 2), # 11 conv_dw(32, 32, 1), # 19 conv_dw(32, 64, 2), # 27 conv_dw(64, 64, 1), # 43 ) self.stage2 = nn.Sequential( conv_dw(64, 128, 2), # 43 + 16 = 59 conv_dw(128, 128, 1), # 59 + 32 = 91 conv_dw(128, 128, 1), # 91 + 32 = 123 conv_dw(128, 128, 1), # 123 + 32 = 155 conv_dw(128, 128, 1), # 155 + 32 = 187 conv_dw(128, 128, 1), # 187 + 32 = 219 ) self.stage3 = nn.Sequential( conv_dw(128, 256, 2), # 219 +3 2 = 241 conv_dw(256, 256, 1), # 241 + 64 = 301 ) self.avg = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(256, 1000) def forward(self, x): x = self.stage1(x) x = self.stage2(x) x = self.stage3(x) x = self.avg(x) # x = self.model(x) x = x.view(-1, 256) x = self.fc(x) return x class ClassHead(nn.Module): def __init__(self, inchannels=512, num_anchors=3): super(ClassHead, self).__init__() self.num_anchors = num_anchors self.conv1x1 = nn.Conv2d( inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0 ) def forward(self, x): out = self.conv1x1(x) out = out.permute(0, 2, 3, 1).contiguous() return out.view(out.shape[0], -1, 2) class BboxHead(nn.Module): def __init__(self, inchannels=512, num_anchors=3): super(BboxHead, self).__init__() self.conv1x1 = nn.Conv2d( inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0 ) def forward(self, x): out = self.conv1x1(x) out = out.permute(0, 2, 3, 1).contiguous() return out.view(out.shape[0], -1, 4) class LandmarkHead(nn.Module): def __init__(self, inchannels=512, num_anchors=3): super(LandmarkHead, self).__init__() self.conv1x1 = nn.Conv2d( inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0 ) def forward(self, x): out = self.conv1x1(x) out = out.permute(0, 2, 3, 1).contiguous() return out.view(out.shape[0], -1, 10) class RetinaFace(nn.Module): def __init__(self, cfg=None, phase="train"): """ :param cfg: Network related settings. :param phase: train or test. """ super(RetinaFace, self).__init__() self.phase = phase backbone = None if cfg["name"] == "mobilenet0.25": backbone = MobileNetV1() elif cfg["name"] == "Resnet50": import torchvision.models as models backbone = models.resnet50(pretrained=cfg["pretrain"]) self.body = _utils.IntermediateLayerGetter( backbone, cfg["return_layers"]) in_channels_stage2 = cfg["in_channel"] in_channels_list = [ in_channels_stage2 * 2, in_channels_stage2 * 4, in_channels_stage2 * 8, ] out_channels = cfg["out_channel"] self.fpn = FPN(in_channels_list, out_channels) self.ssh1 = SSH(out_channels, out_channels) self.ssh2 = SSH(out_channels, out_channels) self.ssh3 = SSH(out_channels, out_channels) self.ClassHead = self._make_class_head( fpn_num=3, inchannels=cfg["out_channel"]) self.BboxHead = self._make_bbox_head( fpn_num=3, inchannels=cfg["out_channel"]) self.LandmarkHead = self._make_landmark_head( fpn_num=3, inchannels=cfg["out_channel"] ) def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2): classhead = nn.ModuleList() for i in range(fpn_num): classhead.append(ClassHead(inchannels, anchor_num)) return classhead def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2): bboxhead = nn.ModuleList() for i in range(fpn_num): bboxhead.append(BboxHead(inchannels, anchor_num)) return bboxhead def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2): landmarkhead = nn.ModuleList() for i in range(fpn_num): landmarkhead.append(LandmarkHead(inchannels, anchor_num)) return landmarkhead def forward(self, inputs): out = self.body(inputs) # FPN fpn = self.fpn(out) # SSH feature1 = self.ssh1(fpn[0]) feature2 = self.ssh2(fpn[1]) feature3 = self.ssh3(fpn[2]) features = [feature1, feature2, feature3] bbox_regressions = torch.cat( [self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1 ) classifications = torch.cat( [self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1 ) ldm_regressions = torch.cat( [self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1 ) if self.phase == "train": output = (bbox_regressions, classifications, ldm_regressions) else: output = ( bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions, ) return output # Adapted from https://github.com/Hakuyume/chainer-ssd def decode(loc: torch.Tensor, priors: torch.Tensor, variances: Tuple[float, float]) -> torch.Tensor: boxes = torch.cat( ( priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]), ), 1, ) boxes[:, :2] -= boxes[:, 2:] / 2 boxes[:, 2:] += boxes[:, :2] return boxes def decode_landm(pre: torch.Tensor, priors: torch.Tensor, variances: Tuple[float, float]) -> torch.Tensor: landms = torch.cat( ( priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], ), dim=1, ) return landms def nms(dets: torch.Tensor, thresh: float) -> List[int]: """Pure Python NMS baseline.""" x1 = dets[:, 0] y1 = dets[:, 1] x2 = dets[:, 2] y2 = dets[:, 3] scores = dets[:, 4] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = torch.flip(scores.argsort(), [0]) keep = [] while order.numel() > 0: i = order[0].item() keep.append(i) xx1 = torch.maximum(x1[i], x1[order[1:]]) yy1 = torch.maximum(y1[i], y1[order[1:]]) xx2 = torch.minimum(x2[i], x2[order[1:]]) yy2 = torch.minimum(y2[i], y2[order[1:]]) w = torch.maximum(torch.tensor(0.0).to(dets), xx2 - xx1 + 1) h = torch.maximum(torch.tensor(0.0).to(dets), yy2 - yy1 + 1) inter = w * h ovr = inter / (areas[i] + areas[order[1:]] - inter) inds = torch.where(ovr <= thresh)[0] order = order[inds + 1] return keep class PriorBox: def __init__(self, cfg: dict, image_size: Tuple[int, int]): self.min_sizes = cfg["min_sizes"] self.steps = cfg["steps"] self.clip = cfg["clip"] self.image_size = image_size self.feature_maps = [ [ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps ] def generate_anchors(self, device) -> torch.Tensor: anchors = [] for k, f in enumerate(self.feature_maps): min_sizes = self.min_sizes[k] for i, j in product(range(f[0]), range(f[1])): for min_size in min_sizes: s_kx = min_size / self.image_size[1] s_ky = min_size / self.image_size[0] dense_cx = [ x * self.steps[k] / self.image_size[1] for x in [j + 0.5] ] dense_cy = [ y * self.steps[k] / self.image_size[0] for y in [i + 0.5] ] for cy, cx in product(dense_cy, dense_cx): anchors += [cx, cy, s_kx, s_ky] # back to torch land output = torch.tensor(anchors).view(-1, 4) if self.clip: output.clamp_(max=1, min=0) return output.to(device=device) cfg_mnet = { "name": "mobilenet0.25", "min_sizes": [[16, 32], [64, 128], [256, 512]], "steps": [8, 16, 32], "variance": [0.1, 0.2], "clip": False, "loc_weight": 2.0, "gpu_train": True, "batch_size": 32, "ngpu": 1, "epoch": 250, "decay1": 190, "decay2": 220, "image_size": 640, "pretrain": True, "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3}, "in_channel": 32, "out_channel": 64, } cfg_re50 = { "name": "Resnet50", "min_sizes": [[16, 32], [64, 128], [256, 512]], "steps": [8, 16, 32], "variance": [0.1, 0.2], "clip": False, "loc_weight": 2.0, "gpu_train": True, "batch_size": 24, "ngpu": 4, "epoch": 100, "decay1": 70, "decay2": 90, "image_size": 840, "pretrain": False, "return_layers": {"layer2": 1, "layer3": 2, "layer4": 3}, "in_channel": 256, "out_channel": 256, } def check_keys(model, pretrained_state_dict): ckpt_keys = set(pretrained_state_dict.keys()) model_keys = set(model.state_dict().keys()) used_pretrained_keys = model_keys & ckpt_keys assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint" return True def remove_prefix(state_dict, prefix): """ Old style model is stored with all names of parameters sharing common prefix 'module.' """ def f(x): return x.split(prefix, 1)[-1] if x.startswith(prefix) else x return {f(key): value for key, value in state_dict.items()} def load_model(model, pretrained_path, load_to_cpu, network: str): if pretrained_path is None: url = pretrained_urls[network] if load_to_cpu: pretrained_dict = torch.utils.model_zoo.load_url( url, map_location=lambda storage, loc: storage ) else: pretrained_dict = torch.utils.model_zoo.load_url( url, map_location=lambda storage, loc: storage.cuda(device) ) else: if load_to_cpu: pretrained_dict = torch.load( pretrained_path, map_location=lambda storage, loc: storage ) else: device = torch.cuda.current_device() pretrained_dict = torch.load( pretrained_path, map_location=lambda storage, loc: storage.cuda( device) ) if "state_dict" in pretrained_dict.keys(): pretrained_dict = remove_prefix( pretrained_dict["state_dict"], "module.") else: pretrained_dict = remove_prefix(pretrained_dict, "module.") check_keys(model, pretrained_dict) model.load_state_dict(pretrained_dict, strict=False) return model def load_net(model_path, network="mobilenet"): if network == "mobilenet": cfg = cfg_mnet elif network == "resnet50": cfg = cfg_re50 else: raise NotImplementedError(network) # net and model net = RetinaFace(cfg=cfg, phase="test") net = load_model(net, model_path, True, network=network) net.eval() cudnn.benchmark = True # net = net.to(device) return net def parse_det(det: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, float]: landmarks = det[5:].reshape(5, 2) box = det[:4] score = det[4] return box, landmarks, score.item() def post_process( loc: torch.Tensor, conf: torch.Tensor, landms: torch.Tensor, prior_data: torch.Tensor, cfg: dict, scale: float, scale1: float, resize, confidence_threshold, top_k, nms_threshold, keep_top_k, ): boxes = decode(loc, prior_data, cfg["variance"]) boxes = boxes * scale / resize # boxes = boxes.cpu().numpy() # scores = conf.cpu().numpy()[:, 1] scores = conf[:, 1] landms_copy = decode_landm(landms, prior_data, cfg["variance"]) landms_copy = landms_copy * scale1 / resize # landms_copy = landms_copy.cpu().numpy() # ignore low scores inds = torch.where(scores > confidence_threshold)[0] boxes = boxes[inds] landms_copy = landms_copy[inds] scores = scores[inds] # keep top-K before NMS order = torch.flip(scores.argsort(), [0])[:top_k] boxes = boxes[order] landms_copy = landms_copy[order] scores = scores[order] # do NMS dets = torch.hstack((boxes, scores.unsqueeze(-1))).to( dtype=torch.float32, copy=False) keep = nms(dets, nms_threshold) # keep = nms(dets, args.nms_threshold,force_cpu=args.cpu) dets = dets[keep, :] landms_copy = landms_copy[keep] # keep top-K faster NMS dets = dets[:keep_top_k, :] landms_copy = landms_copy[:keep_top_k, :] dets = torch.cat((dets, landms_copy), dim=1) # show image dets = sorted(dets, key=lambda x: x[4], reverse=True) dets = [parse_det(x) for x in dets] return dets # @torch.no_grad() def batch_detect(net: nn.Module, images: torch.Tensor, threshold: float = 0.5): confidence_threshold = threshold cfg = cfg_mnet top_k = 5000 nms_threshold = 0.4 keep_top_k = 1 resize = 1 img = images.float() mean = torch.as_tensor([104, 117, 123], dtype=img.dtype, device=img.device).view( 1, 3, 1, 1 ) img -= mean ( _, _, im_height, im_width, ) = img.shape scale = torch.as_tensor( [im_width, im_height, im_width, im_height], dtype=img.dtype, device=img.device, ) scale = scale.to(img.device) loc, conf, landms = net(img) # forward pass priorbox = PriorBox(cfg, image_size=(im_height, im_width)) prior_data = priorbox.generate_anchors(device=img.device) scale1 = torch.as_tensor( [ img.shape[3], img.shape[2], img.shape[3], img.shape[2], img.shape[3], img.shape[2], img.shape[3], img.shape[2], img.shape[3], img.shape[2], ], dtype=img.dtype, device=img.device, ) scale1 = scale1.to(img.device) all_dets = [ post_process( loc_i, conf_i, landms_i, prior_data, cfg, scale, scale1, resize, confidence_threshold, top_k, nms_threshold, keep_top_k, ) for loc_i, conf_i, landms_i in zip(loc, conf, landms) ] rects = [] points = [] scores = [] image_ids = [] for image_id, faces_in_one_image in enumerate(all_dets): for rect, landmarks, score in faces_in_one_image: rects.append(rect) points.append(landmarks) scores.append(score) image_ids.append(image_id) if len(rects) == 0: return dict() return { 'rects': torch.stack(rects, dim=0).to(img.device), 'points': torch.stack(points, dim=0).to(img.device), 'scores': torch.tensor(scores).to(img.device), 'image_ids': torch.tensor(image_ids).to(img.device), } class RetinaFaceDetector(FaceDetector): """RetinaFaceDetector Args: images (torch.Tensor): b x c x h x w, uint8, 0~255. Returns: faces (Dict[str, torch.Tensor]): * image_ids: n, int * rects: n x 4 (x1, y1, x2, y2) * points: n x 5 x 2 (x, y) * scores: n """ def __init__(self, conf_name: Optional[str] = None, model_path: Optional[str] = None, threshold=0.8) -> None: super().__init__() if conf_name is None: conf_name = 'mobilenet' self.net = load_net(model_path, conf_name) self.threshold = threshold self.eval() def forward(self, images: torch.Tensor) -> Dict[str, torch.Tensor]: return batch_detect(self.net, images.clone(), threshold=self.threshold)