import cv2
import os
import sys

# The detector package and the parent directory must be importable before
# the project-local imports below are resolved.
sys.path.insert(0, "FaceBoxesV2")
sys.path.insert(0, "..")
from math import floor
from faceboxes_detector import *
import torch
import torch.nn.parallel
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.models as models
from networks import *
from functions import *
from PIPNet.reverse_index import ri1, ri2


class Config:
    """Hyper-parameters describing the landmark network used by this demo.

    Only ``pretrained``, ``num_nb``, ``num_lms``, ``input_size``,
    ``net_stride`` and ``use_gpu`` are read in this file (by
    ``get_lmk_model``); the remaining attributes are kept unchanged.
    """

    def __init__(self):
        self.det_head = "pip"
        self.net_stride = 32
        self.batch_size = 16
        self.init_lr = 0.0001
        self.num_epochs = 60
        self.decay_steps = [30, 50]
        self.input_size = 256
        self.backbone = "resnet101"
        self.pretrained = True
        self.criterion_cls = "l2"
        self.criterion_reg = "l1"
        self.cls_loss_weight = 10
        self.reg_loss_weight = 1
        self.num_lms = 98
        self.save_interval = self.num_epochs
        self.num_nb = 10
        self.use_gpu = True
        self.gpu_id = 3


def get_lmk_model():
    """Build the landmark network and the face detector.

    The network is created from ``Config``, moved to ``cuda:0`` when the
    configuration asks for a GPU and CUDA is available (CPU otherwise), and
    its weights are loaded from a fixed snapshot file. The face detector is
    created on that same device.

    Returns:
        A ``(net, detector)`` tuple.
    """
    cfg = Config()

    # NOTE(review): the backbone weights requested here are presumably
    # superseded by the snapshot loaded below -- confirm against
    # Pip_resnet101 before switching `pretrained` off.
    resnet101 = models.resnet101(pretrained=cfg.pretrained)
    net = Pip_resnet101(
        resnet101,
        cfg.num_nb,
        num_lms=cfg.num_lms,
        input_size=cfg.input_size,
        net_stride=cfg.net_stride,
    )

    if cfg.use_gpu:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    net = net.to(device)

    weight_file = "/apdcephfs/share_1290939/ahbanliang/codes/PIPNet/snapshots/WFLW/pip_32_16_60_r101_l2_l1_10_1_nb10/epoch59.pth"
    state_dict = torch.load(weight_file, map_location=device)
    net.load_state_dict(state_dict)

    # Give the detector the device chosen for the network instead of an
    # unconditional "cuda:0", so both agree on hosts without CUDA.
    detector = FaceBoxesDetector(
        "FaceBoxes",
        "FaceBoxesV2/weights/FaceBoxesV2.pth",
        use_gpu=device.type == "cuda",
        device=str(device),
    )
    return net, detector


def _expand_detection_box(detection, image_width, image_height, scale):
    """Return the enlarged, image-clamped box ``(xmin, ymin, xmax, ymax)``.

    ``detection`` is read at indices 2..5 as ``xmin, ymin, width, height``.
    All returned coordinates are inclusive.
    """
    det_xmin = detection[2]
    det_ymin = detection[3]
    det_width = detection[4]
    det_height = detection[5]
    det_xmax = det_xmin + det_width - 1
    det_ymax = det_ymin + det_height - 1

    pad_x = int(det_width * (scale - 1) / 2)
    pad_y = int(det_height * (scale - 1) / 2)
    det_xmin -= pad_x
    # remove a part of top area for alignment, see paper for details
    det_ymin += pad_y
    det_xmax += pad_x
    det_ymax += pad_y

    det_xmin = max(det_xmin, 0)
    det_ymin = max(det_ymin, 0)
    det_xmax = min(det_xmax, image_width - 1)
    det_ymax = min(det_ymax, image_height - 1)
    return det_xmin, det_ymin, det_xmax, det_ymax


def demo_image(
    image_file,
    net,
    detector,
    input_size=256,
    net_stride=32,
    num_nb=10,
    use_gpu=True,
    device="cuda:0",
):
    """Detect faces in an image, predict 98 landmarks per face and save a plot.

    Every detected face box is enlarged, cropped, resized to
    ``input_size`` x ``input_size`` and passed through ``forward_pip``. The
    direct and neighbour-based predictions are averaged per landmark, and the
    result is drawn together with the face box. The annotated image is written
    to ``images/1_out.jpg``.

    Args:
        image_file: Path of the image to process.
        net: Landmark network; it is switched to eval mode.
        detector: Object whose ``detect(image, threshold, 1)`` returns the
            detections as the first element of a pair.
        input_size: Side length of the square network input.
        net_stride: Forwarded to ``forward_pip``.
        num_nb: Forwarded to ``forward_pip``.
        use_gpu: Not consulted; kept so existing callers keep working.
        device: Device the network input is moved to. A CUDA device is
            replaced by the CPU when CUDA is unavailable.

    Raises:
        OSError: If the image cannot be read or the result cannot be written.
    """
    my_thresh = 0.6
    det_box_scale = 1.2
    num_lms = 98
    output_file = "images/1_out.jpg"

    device = torch.device(device)
    if device.type == "cuda" and not torch.cuda.is_available():
        device = torch.device("cpu")

    net.eval()
    preprocess = transforms.Compose(
        [
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    reverse_index1, reverse_index2, max_len = ri1, ri2, 17

    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError("could not read image: {}".format(image_file))
    image_height, image_width, _ = image.shape

    # Draw on a copy so that rectangles and landmarks never end up inside a
    # crop that is fed to the network.
    canvas = image.copy()

    detections, _ = detector.detect(image, my_thresh, 1)
    for detection in detections:
        det_xmin, det_ymin, det_xmax, det_ymax = _expand_detection_box(
            detection, image_width, image_height, det_box_scale
        )
        if det_xmax <= det_xmin or det_ymax <= det_ymin:
            # The crop below would be empty and cv2.resize would fail.
            continue
        det_width = det_xmax - det_xmin + 1
        det_height = det_ymax - det_ymin + 1

        det_crop = image[det_ymin:det_ymax, det_xmin:det_xmax, :]
        det_crop = cv2.resize(det_crop, (input_size, input_size))
        # OpenCV images are BGR; reverse the channel axis for PIL.
        inputs = Image.fromarray(det_crop[:, :, ::-1].astype("uint8"), "RGB")
        inputs = preprocess(inputs).unsqueeze(0)
        inputs = inputs.to(device)

        (
            lms_pred_x,
            lms_pred_y,
            lms_pred_nb_x,
            lms_pred_nb_y,
            _outputs_cls,
            _max_cls,
        ) = forward_pip(net, inputs, preprocess, input_size, net_stride, num_nb)

        # Average each landmark's own prediction with the predictions gathered
        # through the reverse-index tables.
        tmp_nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(num_lms, max_len)
        tmp_nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(num_lms, max_len)
        tmp_x = torch.mean(torch.cat((lms_pred_x, tmp_nb_x), dim=1), dim=1).view(-1, 1)
        tmp_y = torch.mean(torch.cat((lms_pred_y, tmp_nb_y), dim=1), dim=1).view(-1, 1)
        lms_pred_merge = torch.cat((tmp_x, tmp_y), dim=1).flatten()
        lms_pred_merge = lms_pred_merge.cpu().numpy()

        cv2.rectangle(canvas, (det_xmin, det_ymin), (det_xmax, det_ymax), (0, 0, 255), 2)
        for lm_idx in range(num_lms):
            # Predictions are scaled by the box size and offset by its corner
            # to obtain pixel positions in the full image.
            x_pred = lms_pred_merge[lm_idx * 2] * det_width
            y_pred = lms_pred_merge[lm_idx * 2 + 1] * det_height
            cv2.circle(
                canvas,
                (int(x_pred) + det_xmin, int(y_pred) + det_ymin),
                1,
                (0, 0, 255),
                2,
            )

    # cv2.imwrite returns False (without raising) when the target directory
    # is missing, so create it and check the result.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    if not cv2.imwrite(output_file, canvas):
        raise OSError("could not write image: {}".format(output_file))


if __name__ == "__main__":
    net, detector = get_lmk_model()
    demo_image(
        "/apdcephfs/private_ahbanliang/codes/Real-ESRGAN-master/tmp_frames/yanikefu/frame00000046.png",
        net,
        detector,
    )