# Repository path: KDTalker/difpoint/src/models/motion_extractor_model.py
# Hosting-page metadata (not part of the module): ChaolongYang's picture,
# "Upload 242 files", commit 475d332 (verified)
# -*- coding: utf-8 -*-
# @Author : wenshao
# @Email : [email protected]
# @Project : FasterLivePortrait
# @FileName: motion_extractor_model.py
import pdb
import numpy as np
from .base_model import BaseModel
import torch
from torch.cuda import nvtx
from .predictor import numpy_to_torch_dtype_dict
import torch.nn.functional as F
def headpose_pred_to_degree(pred):
    """Convert per-bin head-pose scores into a single value per row.

    When ``pred`` has more than one dimension and exactly 66 entries along
    axis 1, a softmax is taken along that axis, the expected bin index is
    computed, and the result is mapped with ``expected_index * 3 - 97.5``.
    Any other input is returned unchanged.

    Args:
        pred: array of shape (bs, 66) or (bs, 1) or others.

    Returns:
        For (bs, 66) input, an array of shape (bs,). Otherwise ``pred``
        itself (the same object, not a copy).
    """
    if pred.ndim > 1 and pred.shape[1] == 66:
        # NOTE: note that the average is modified to 97.5
        idx_array = np.arange(0, 66)
        logits = np.asanyarray(pred)
        # Subtract the per-row maximum before exponentiating. Softmax is
        # invariant to this shift, and it keeps np.exp from overflowing to
        # inf (which would turn the whole row into NaN) on large scores.
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        exp_scores = np.exp(shifted)
        prob = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
        degree = np.sum(prob * idx_array, axis=1) * 3 - 97.5
        return degree
    return pred
class MotionExtractorModel(BaseModel):
    """
    MotionExtractorModel

    Wrapper that runs the motion-extractor network through ``self.predictor``
    and post-processes its seven outputs (pitch, yaw, roll, t, exp, scale, kp).

    NOTE(review): ``self.predictor``, ``self.predict_type``, ``self.device``
    and ``self.cudaStream`` are not assigned in this class; presumably they
    are set up by ``BaseModel`` -- confirm there.
    """

    def __init__(self, **kwargs):
        """Initialise the base model and read the ``flag_refine_info`` option.

        Args:
            **kwargs: forwarded unchanged to ``BaseModel.__init__``. The key
                ``flag_refine_info`` (default ``True``) controls whether
                ``output_process`` post-processes the raw predictions.
        """
        super().__init__(**kwargs)
        self.flag_refine_info = kwargs.get("flag_refine_info", True)

    def input_process(self, *data):
        """Turn an image array into a float32 network input.

        The first positional argument is cast to float32, divided by 255,
        transposed with axes ``(2, 0, 1)`` and given a leading axis of size 1.

        Args:
            *data: only ``data[0]`` is used; assumes an HxWxC image array
                with values in 0..255 -- TODO confirm against callers.

        Returns:
            A float32 array with one more dimension than the input.
        """
        img = data[0].astype(np.float32)
        img /= 255.0
        img = np.transpose(img, (2, 0, 1))
        return img[None]

    def output_process(self, *data):
        """Reorder and optionally refine the raw predictor outputs.

        Args:
            *data: the seven predictor outputs. For ``predict_type == "trt"``
                they are unpacked as (kp, pitch, yaw, roll, t, exp, scale);
                otherwise as (pitch, yaw, roll, t, exp, scale, kp).

        Returns:
            The tuple (pitch, yaw, roll, t, exp, scale, kp). When
            ``self.flag_refine_info`` is true, pitch/yaw/roll are passed
            through ``headpose_pred_to_degree`` and given a trailing axis,
            and kp/exp are reshaped to (bs, -1, 3).
        """
        if self.predict_type == "trt":
            kp, pitch, yaw, roll, t, exp, scale = data
        else:
            pitch, yaw, roll, t, exp, scale, kp = data
        if self.flag_refine_info:
            bs = kp.shape[0]
            pitch = headpose_pred_to_degree(pitch)[:, None]  # Bx1
            yaw = headpose_pred_to_degree(yaw)[:, None]  # Bx1
            roll = headpose_pred_to_degree(roll)[:, None]  # Bx1
            kp = kp.reshape(bs, -1, 3)  # BxNx3
            exp = exp.reshape(bs, -1, 3)  # BxNx3
        return pitch, yaw, roll, t, exp, scale, kp

    def predict_trt(self, *data):
        """Run the predictor on torch tensors and return numpy outputs.

        Args:
            *data: one entry per item in ``self.predictor.inputs``, in the
                same order. ``torch.Tensor`` entries are passed through as-is;
                anything else goes through ``torch.from_numpy`` and is moved
                to ``self.device`` with the dtype looked up from the input's
                ``'dtype'`` field.

        Returns:
            A list of numpy arrays, one per entry of
            ``self.predictor.outputs``, in that order.
        """
        nvtx.range_push("forward")
        try:
            feed_dict = {}
            for i, inp in enumerate(self.predictor.inputs):
                value = data[i]
                if not isinstance(value, torch.Tensor):
                    value = torch.from_numpy(value).to(
                        device=self.device,
                        dtype=numpy_to_torch_dtype_dict[inp['dtype']],
                    )
                feed_dict[inp['name']] = value
            preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
            return [
                preds_dict[out["name"]].cpu().numpy()
                for out in self.predictor.outputs
            ]
        finally:
            # Pop in a finally block so the NVTX range stack stays balanced
            # even when tensor conversion or inference raises.
            nvtx.range_pop()

    def predict(self, *data):
        """Run inference on ``data[0]`` and return the processed outputs.

        Args:
            *data: only ``data[0]`` is used.

        Returns:
            The tuple produced by ``output_process``.
        """
        # input_process is intentionally not applied here: data[0] is handed
        # to the predictor exactly as received.
        img = data[0]
        if self.predict_type == "trt":
            preds = self.predict_trt(img)
        else:
            preds = self.predictor.predict(img)
        outputs = self.output_process(*preds)
        return outputs