Create helper_cpu.py
src/utils/helper_cpu.py  ADDED  (+173 -0)

@@ -0,0 +1,173 @@
# coding: utf-8

"""
Utility functions and classes to handle feature extraction and model loading
"""

import os
import os.path as osp
import torch
from collections import OrderedDict
import psutil
from rich.console import Console
from rich.progress import Progress
from ..modules.spade_generator import SPADEDecoder
from ..modules.warping_network import WarpingNetwork
from ..modules.motion_extractor import MotionExtractor
from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor
from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork

from rich.console import Console
import psutil

console = Console()

def show_memory_usage():
    """
    Display the current memory usage in the terminal using rich.
    """
    mem_info = psutil.virtual_memory()
    total_mem = mem_info.total / (1024 ** 3)  # Convert to GB
    used_mem = mem_info.used / (1024 ** 3)  # Convert to GB
    available_mem = mem_info.available / (1024 ** 3)  # Convert to GB

    console.log(f"[bold green]Memory Usage:[/bold green] [bold red]{used_mem:.2f} GB[/bold red] used of [bold blue]{total_mem:.2f} GB[/bold blue]")
    console.log(f"[bold green]Available Memory:[/bold green] [bold yellow]{available_mem:.2f} GB[/bold yellow]")


def suffix(filename):
    """a.jpg -> jpg"""
    pos = filename.rfind(".")
    if pos == -1:
        return ""
    return filename[pos + 1:]


def prefix(filename):
    """a.jpg -> a"""
    pos = filename.rfind(".")
    if pos == -1:
        return filename
    return filename[:pos]


def basename(filename):
    """a/b/c.jpg -> c"""
    return prefix(osp.basename(filename))


def is_video(file_path):
    if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path):
        return True
    return False


def is_template(file_path):
    if file_path.endswith(".pkl"):
        return True
    return False

def mkdir(d, log=False):
    # return self-assigned `d`, for one line code
    if not osp.exists(d):
        os.makedirs(d, exist_ok=True)
        if log:
            print(f"Make dir: {d}")
    return d

def squeeze_tensor_to_numpy(tensor):
    out = tensor.data.squeeze(0).cpu().numpy()
    return out


def dct2cpu(dct: dict, device='cpu'):
    for key in dct:
        dct[key] = torch.tensor(dct[key]).to(device)
    return dct


def concat_feat(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """
    kp_source: (bs, k, 3)
    kp_driving: (bs, k, 3)
    Return: (bs, 2k*3)
    """
    bs_src = kp_source.shape[0]
    bs_dri = kp_driving.shape[0]
    assert bs_src == bs_dri, 'batch size must be equal'

    feat = torch.cat([kp_source.view(bs_src, -1), kp_driving.view(bs_dri, -1)], dim=1)
    return feat


def remove_ddp_duplicate_key(state_dict):
    state_dict_new = OrderedDict()
    for key in state_dict.keys():
        state_dict_new[key.replace('module.', '')] = state_dict[key]
    return state_dict_new


def load_model(ckpt_path, model_config, device, model_type):
    model_params = model_config['model_params'][f'{model_type}_params']

    if model_type == 'appearance_feature_extractor':
        model = AppearanceFeatureExtractor(**model_params).to('cpu')
    elif model_type == 'motion_extractor':
        model = MotionExtractor(**model_params).to('cpu')
    elif model_type == 'warping_module':
        model = WarpingNetwork(**model_params).to('cpu')
    elif model_type == 'spade_generator':
        model = SPADEDecoder(**model_params).to('cpu')
    elif model_type == 'stitching_retargeting_module':
        # Special handling for stitching and retargeting module
        config = model_config['model_params']['stitching_retargeting_module_params']
        checkpoint = torch.load(ckpt_path, map_location='cpu')

        stitcher = StitchingRetargetingNetwork(**config.get('stitching'))
        stitcher.load_state_dict(remove_ddp_duplicate_key(checkpoint['retarget_shoulder']))
        stitcher = stitcher.to('cpu')
        stitcher.eval()

        retargetor_lip = StitchingRetargetingNetwork(**config.get('lip'))
        retargetor_lip.load_state_dict(remove_ddp_duplicate_key(checkpoint['retarget_mouth']))
        retargetor_lip = retargetor_lip.to('cpu')
        retargetor_lip.eval()

        retargetor_eye = StitchingRetargetingNetwork(**config.get('eye'))
        retargetor_eye.load_state_dict(remove_ddp_duplicate_key(checkpoint['retarget_eye']))
        retargetor_eye = retargetor_eye.to('cpu')
        retargetor_eye.eval()

        return {
            'stitching': stitcher,
            'lip': retargetor_lip,
            'eye': retargetor_eye
        }
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    model.eval()
    return model


# Get coefficients of Eqn. 7
def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R_t_0, R_t_i):
    if config.relative:
        new_rotation = (R_t_i @ R_t_0.permute(0, 2, 1)) @ R_s
        new_expression = s_kp_info['exp'] + (t_i_kp_info['exp'] - t_0_kp_info['exp'])
    else:
        new_rotation = R_t_i
        new_expression = t_i_kp_info['exp']
    new_translation = s_kp_info['t'] + (t_i_kp_info['t'] - t_0_kp_info['t'])
    new_translation[..., 2].fill_(0)  # Keep the z-axis unchanged
    new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale'])
    return new_rotation, new_expression, new_translation, new_scale


def load_description(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        content = f.read()
    return content
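
For reference, a minimal usage sketch of the helpers added above. The import path, config file path, checkpoint path, and YAML layout below are assumptions for illustration only; `load_model` and `show_memory_usage` are the functions defined in this commit, and `load_model` moves every submodule to the CPU regardless of the `device` argument.

import yaml  # PyYAML, assumed to be available

from src.utils.helper_cpu import load_model, show_memory_usage  # assumed import path

# Hypothetical model config; the real YAML must provide
# model_config['model_params']['motion_extractor_params'] for this call.
with open("src/config/models.yaml", "r") as f:
    model_config = yaml.safe_load(f)

show_memory_usage()  # log RAM usage before loading weights

motion_extractor = load_model(
    ckpt_path="pretrained_weights/motion_extractor.pth",  # hypothetical checkpoint path
    model_config=model_config,
    device="cpu",  # informational only: load_model pins the model to CPU itself
    model_type="motion_extractor",
)

show_memory_usage()  # log RAM usage again to see the loaded model's footprint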