# Source: KDTalker / difpoint/src/models/landmark_model.py
# (uploaded by ChaolongYang — "Upload 242 files", commit 475d332, verified)
# -*- coding: utf-8 -*-
# @Author : wenshao
# @Email : [email protected]
# @Project : FasterLivePortrait
# @FileName: landmark_model.py
from .base_model import BaseModel
import cv2
import numpy as np
from difpoint.src.utils.crop import crop_image, _transform_pts
import torch
from torch.cuda import nvtx
from .predictor import numpy_to_torch_dtype_dict
class LandmarkModel(BaseModel):
    """
    Landmark model.

    Crops (or resizes) an RGB image to ``self.dsize`` x ``self.dsize``, runs
    the underlying predictor (TensorRT when ``self.predict_type == "trt"``,
    otherwise ``self.predictor.predict``) and maps the predicted landmarks
    back into original-image pixel coordinates.
    """

    def __init__(self, **kwargs):
        super(LandmarkModel, self).__init__(**kwargs)
        # Side length (pixels) of the square network input.
        self.dsize = 224

    def input_process(self, *data):
        """
        Build the network input from an image and optional landmarks.

        Args:
            *data: ``(img_rgb,)`` or ``(img_rgb, lmk)``. When ``lmk`` is given
                and not None, ``crop_image`` builds a landmark-guided crop;
                otherwise the whole image is resized to ``dsize`` x ``dsize``.

        Returns:
            (inp, crop_dct): ``inp`` is a float32 array of shape
            ``(1, C, H, W)`` (pixel values divided by 255); ``crop_dct``
            contains ``'M_c2o'``, the crop-to-original transform consumed by
            ``output_process``.
        """
        if len(data) > 1:
            img_rgb, lmk = data
        else:
            img_rgb = data[0]
            lmk = None

        if lmk is not None:
            crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1)
            img_crop_rgb = crop_dct['img_crop']
        else:
            # NOTE: force resize to dsize x dsize, NOT RECOMMENDED — this
            # distorts the aspect ratio of non-square images!
            img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize))
            # cv2.resize stretches each axis independently, so undoing it
            # needs one scale per axis. Assumes points are (x, y) pairs:
            # x scales with the image width, y with the image height.
            # (A single max(h, w) scale misplaces points on non-square images.)
            h, w = img_rgb.shape[:2]
            scale_x = w / self.dsize
            scale_y = h / self.dsize
            crop_dct = {
                'M_c2o': np.array([
                    [scale_x, 0., 0.],
                    [0., scale_y, 0.],
                    [0., 0., 1.],
                ], dtype=np.float32),
            }

        # HxWxC -> 1xCxHxW, scaled by 1/255. The transpose only moves axes;
        # channel order is unchanged (the image is already RGB).
        inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...]
        return inp, crop_dct

    def output_process(self, *data):
        """
        Map raw predictor outputs back to original-image coordinates.

        Args:
            *data: ``(out_pts, crop_dct)`` — the sequence of predictor outputs
                and the dict returned by ``input_process``.

        Returns:
            Landmarks reshaped to ``(N, 2)`` and transformed by
            ``crop_dct['M_c2o']`` into original-image coordinates.
        """
        out_pts, crop_dct = data
        # Output index 2 holds the landmarks; presumably normalized to the
        # crop (hence the * dsize) — TODO confirm against the model export.
        lmk = out_pts[2].reshape(-1, 2) * self.dsize  # scale to 0..dsize crop pixels
        lmk = _transform_pts(lmk, M=crop_dct['M_c2o'])
        return lmk

    def predict_trt(self, *data):
        """
        Run the TensorRT predictor.

        Positional arguments are bound, in order, to
        ``self.predictor.inputs``. Torch tensors are passed through as-is;
        anything else goes through ``torch.from_numpy`` and is moved to
        ``self.device`` with the dtype declared for that input.

        Returns:
            list of numpy arrays, one per entry of ``self.predictor.outputs``
            and in the same order.
        """
        nvtx.range_push("forward")
        try:
            feed_dict = {}
            for i, inp in enumerate(self.predictor.inputs):
                value = data[i]
                if isinstance(value, torch.Tensor):
                    feed_dict[inp['name']] = value
                else:
                    feed_dict[inp['name']] = torch.from_numpy(value).to(
                        device=self.device,
                        dtype=numpy_to_torch_dtype_dict[inp['dtype']],
                    )
            preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
            return [preds_dict[out["name"]].cpu().numpy() for out in self.predictor.outputs]
        finally:
            # Pop even if inference raises, so NVTX ranges stay balanced.
            nvtx.range_pop()

    def predict(self, *data):
        """
        Predict landmarks for an image.

        Args:
            *data: ``(img_rgb,)`` or ``(img_rgb, lmk)`` — see ``input_process``.

        Returns:
            Landmarks in original-image coordinates (see ``output_process``).
        """
        inp, crop_dct = self.input_process(*data)
        if self.predict_type == "trt":
            preds = self.predict_trt(inp)
        else:
            preds = self.predictor.predict(inp)
        return self.output_process(preds, crop_dct)