Spaces:
Running
on
Zero
Running
on
Zero
# -*- coding: utf-8 -*-
# @Author : wenshao
# @Email : [email protected]
# @Project : FasterLivePortrait
# @FileName: warping_spade_model.py
import pdb | |
import numpy as np | |
from .base_model import BaseModel | |
import torch | |
from torch.cuda import nvtx | |
from .predictor import numpy_to_torch_dtype_dict | |
class WarpingSpadeModel(BaseModel):
    """
    WarpingSpade Model.

    Thin wrapper that prepares the inputs for ``self.predictor``, runs it
    (either through the TensorRT path or the generic predictor path) and
    converts the first output into an image-range tensor.

    NOTE(review): ``self.predictor``, ``self.predict_type``, ``self.device``
    and ``self.cudaStream`` are not set in this class; they are presumably
    provided by ``BaseModel`` -- confirm.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``BaseModel``."""
        super().__init__(**kwargs)

    def input_process(self, *data):
        """Reorder the caller's inputs into the order fed to the predictor.

        Args:
            *data: exactly three items, ``(feature_3d, kp_source, kp_driving)``.

        Returns:
            tuple: ``(feature_3d, kp_driving, kp_source)``.
        """
        feature_3d, kp_source, kp_driving = data
        # The last two items are swapped on purpose; presumably this matches
        # the input order of the exported network -- verify against the model.
        return feature_3d, kp_driving, kp_source

    def output_process(self, *data):
        """Convert the first predictor output into a 0..255 tensor.

        Args:
            *data: predictor outputs; only ``data[0]`` is used. On the
                non-TRT path it is converted with ``torch.from_numpy``; on the
                TRT path it is used as-is (``predict_trt`` returns tensors).

        Returns:
            torch.Tensor: the first element along dim 0 after permuting the
            4-D output with ``(0, 2, 3, 1)``, clipping to ``[0, 1]`` and
            scaling by 255.
        """
        if self.predict_type != "trt":
            out = torch.from_numpy(data[0]).to(self.device).float()
        else:
            out = data[0]
        # Presumably NCHW -> NHWC (channels-last) -- TODO confirm layout.
        out = out.permute(0, 2, 3, 1)
        out = torch.clip(out, 0, 1) * 255
        return out[0]

    def predict_trt(self, *data):
        """Run the TensorRT predictor inside an NVTX "forward" range.

        Args:
            *data: positional inputs, matched by index against
                ``self.predictor.inputs``. ``torch.Tensor`` values are passed
                through untouched; anything else is converted with
                ``torch.from_numpy`` and moved to ``self.device`` with the
                dtype looked up from the input's declared ``dtype``.

        Returns:
            list[torch.Tensor]: a clone of every output listed in
            ``self.predictor.outputs``, in that order.
        """
        nvtx.range_push("forward")
        try:
            feed_dict = {}
            for i, inp in enumerate(self.predictor.inputs):
                value = data[i]
                if not isinstance(value, torch.Tensor):
                    value = torch.from_numpy(value).to(
                        device=self.device,
                        dtype=numpy_to_torch_dtype_dict[inp['dtype']],
                    )
                feed_dict[inp['name']] = value
            preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
            # Clone each output; presumably the predictor reuses its output
            # buffers between calls -- confirm against the predictor.
            return [preds_dict[out["name"]].clone() for out in self.predictor.outputs]
        finally:
            # Always close the range, even if conversion or inference raised,
            # so the NVTX range stack stays balanced.
            nvtx.range_pop()

    def predict(self, *data):
        """Run the full pipeline: reorder inputs, infer, post-process.

        Args:
            *data: ``(feature_3d, kp_source, kp_driving)``; see
                ``input_process``.

        Returns:
            torch.Tensor: the result of ``output_process`` on the predictions.
        """
        data = self.input_process(*data)
        if self.predict_type == "trt":
            preds = self.predict_trt(*data)
        else:
            preds = self.predictor.predict(*data)
        return self.output_process(*preds)