#################################################
# Copyright (c) 2021-present, xiaobing.ai, Inc. #
# All rights reserved.                          #
#################################################
# CV Research, DEV(USA) xiaobing.               #
# written by [email protected]       #
#################################################
##### python internal and external packages
import os
import cv2
import torch
import torch.nn as nn
import numpy as np

##### self-defined packages
from lib.models.fd.models.retinaface import RetinaFace
from lib.models.fd.layers.functions.prior_box import PriorBox
from lib.models.fd.utils.box_utils import decode, decode_landm
from lib.models.fd.utils.nms.py_cpu_nms import py_cpu_nms
# from lib.models.fd.config import cfg_re50 as cfg


def check_keys(model, pretrained_state_dict):
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    print('Missing keys:{}'.format(len(missing_keys)))
    print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    print('Used keys:{}'.format(len(used_pretrained_keys)))
    assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
    return True


def remove_prefix(state_dict, prefix):
    ''' Old-style checkpoints store every parameter name with the common prefix 'module.' '''
    print('remove prefix \'{}\''.format(prefix))
    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}
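# For example, remove_prefix({'module.body.conv1.weight': w}, 'module.') returns
# {'body.conv1.weight': w}; keys that do not start with the prefix pass through unchanged.

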
def load_model(model, pretrained_path, load_to_cpu):
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model


class faceDetector(nn.Module):
    def __init__(self, cfg):
        super(faceDetector, self).__init__()
        if cfg.model.fd.model_name == "retinaface":
            self.model = RetinaFace(cfg=cfg.model.fd.config, phase='test')
            self.model = load_model(self.model, cfg.model.fd.model_path, False)
        else:
            raise ValueError("Error: the face detection model {} does not exist".format(cfg.model.fd.model_name))
        self.model.eval()
        self.model.cuda()
        self.resize_h = cfg.model.fd.img_size
        self.confidence_threshold = cfg.model.fd.thres  # 0.8
        self.nms_threshold = cfg.model.fd.nms_thres  # 0.4
        self.cfg = cfg.model.fd.config

    def _transform(self, img):
        h, w, c = img.shape
        ratio = self.resize_h / h
        img = cv2.resize(img, None, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)
        img = img.astype(np.float32)
        img -= (104., 117., 123.)  # subtract the per-channel BGR means used at training time
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.cuda()
        return img, ratio

    def forward(self, img):
        img, ratio = self._transform(img)
        loc, conf, ldmks = self.model(img)
        _, _, h, w = img.shape
        # generate anchor boxes for the resized input and decode the box regression outputs
        priorbox = PriorBox(self.cfg, image_size=(h, w))
        priors = priorbox.forward()
        priors = priors.cuda()
        prior_data = priors.data
        boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
        boxes = boxes * torch.Tensor([w, h, w, h]).cuda() / ratio  # map back to original image scale
        boxes = boxes.cpu().numpy()
        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
        # drop low-confidence detections
        inds = np.where(scores > self.confidence_threshold)[0]
        boxes = boxes[inds]
        scores = scores[inds]
        # keep top-K before NMS
        order = scores.argsort()[::-1]
        # order = scores.argsort()[::-1][:args.top_k]
        boxes = boxes[order]
        scores = scores[order]
        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = py_cpu_nms(dets, self.nms_threshold)
        dets = dets[keep, :]
        return dets
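

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original pipeline). Assumptions: the RetinaFace
# cfg_re50 dict from lib.models.fd.config (see the commented import above) and
# a local checkpoint path are available; the real project builds `cfg` from its
# own config system, so the SimpleNamespace below only mirrors the fields that
# faceDetector actually reads.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    from lib.models.fd.config import cfg_re50

    fd_cfg = SimpleNamespace(
        model_name="retinaface",
        config=cfg_re50,
        model_path="weights/Resnet50_Final.pth",  # hypothetical checkpoint path
        img_size=640,
        thres=0.8,
        nms_thres=0.4,
    )
    cfg = SimpleNamespace(model=SimpleNamespace(fd=fd_cfg))

    detector = faceDetector(cfg)
    img = cv2.imread("sample.jpg")  # BGR uint8 image, HxWx3
    with torch.no_grad():
        dets = detector(img)  # (N, 5) array: x1, y1, x2, y2, score in original image coordinates
    print("detected {} faces".format(len(dets)))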