import argparse
import json
import os
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
import tyro
import yaml
from loguru import logger
from PIL import Image

from external.human_matting import StyleMatteEngine as HumanMattingEngine
from external.landmark_detection.FaceBoxesV2.faceboxes_detector import \
    FaceBoxesDetector
from external.landmark_detection.infer_image import Alignment
from external.vgghead_detector import VGGHeadDetector
from vhap.config.base import BaseTrackingConfig
from vhap.export_as_nerf_dataset import (NeRFDatasetWriter,
                                         TrackedFLAMEDatasetWriter, split_json)
from vhap.model.tracker import GlobalTracker

# Error codes for the various processing failures.
ERROR_CODE = {'FailedToDetect': 1, 'FailedToOptimize': 2, 'FailedToExport': 3}
def expand_bbox(bbox, scale=1.1):
    """Expand a bounding box to a square of side sqrt(area) * scale."""
    xmin, ymin, xmax, ymax = bbox.unbind(dim=-1)
    center_x, center_y = (xmin + xmax) / 2, (ymin + ymax) / 2
    extension_size = torch.sqrt((ymax - ymin) * (xmax - xmin)) * scale
    x_min_expanded = center_x - extension_size / 2
    x_max_expanded = center_x + extension_size / 2
    y_min_expanded = center_y - extension_size / 2
    y_max_expanded = center_y + extension_size / 2
    return torch.stack(
        [x_min_expanded, y_min_expanded, x_max_expanded, y_max_expanded],
        dim=-1)
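
# Sanity check (hypothetical values): with the default scale=1.1, a 100x100
# box becomes a centered square of side sqrt(100 * 100) * 1.1 = 110:
#   expand_bbox(torch.tensor([0., 0., 100., 100.]))
#   -> tensor([ -5.,  -5., 105., 105.])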
def load_config(src_folder: Path):
    """Load the tracking configuration from the given source folder."""
    config_file_path = src_folder / 'config.yml'
    if not config_file_path.exists():
        # Fall back to the latest run folder (timestamped names sort
        # chronologically, so the lexicographically last one is newest).
        src_folder = sorted(src_folder.iterdir())[-1]
        config_file_path = src_folder / 'config.yml'
    assert config_file_path.exists(), f'File not found: {config_file_path}'
    config_data = yaml.load(config_file_path.read_text(), Loader=yaml.Loader)
    return src_folder, config_data
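
# Expected layout (hypothetical run name): when config.yml is not directly in
# src_folder, load_config descends into the newest run folder, e.g.
#   tracking/
#   └── 2024-01-01_12-00-00/
#       └── config.yml
# and returns that run folder together with the parsed config.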
class FlameTrackingSingleImage:
    """Tracks and processes a single image with FLAME-based optimization."""

    def __init__(
            self,
            output_dir,
            alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
            vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
            human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
            facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
            detect_iris_landmarks=False):
        logger.info(f'Output Directory: {output_dir}')
        start_time = time.time()
        logger.info('Loading Pre-trained Models...')

        self.output_dir = output_dir
        self.output_preprocess = os.path.join(output_dir, 'preprocess')
        self.output_tracking = os.path.join(output_dir, 'tracking')
        self.output_export = os.path.join(output_dir, 'export')
        self.device = 'cuda:0'

        # Load alignment (68-keypoint landmark) model.
        assert os.path.exists(
            alignment_model_path), f'{alignment_model_path} does not exist!'
        args = self._parse_args()
        args.model_path = alignment_model_path
        self.alignment = Alignment(args,
                                   alignment_model_path,
                                   dl_framework='pytorch',
                                   device_ids=[0])

        # Load VGG head detector.
        assert os.path.exists(
            vgghead_model_path), f'{vgghead_model_path} does not exist!'
        self.vgghead_encoder = VGGHeadDetector(
            device=self.device, vggheadmodel_path=vgghead_model_path)

        # Load human matting model.
        assert os.path.exists(
            human_matting_path), f'{human_matting_path} does not exist!'
        self.matting_engine = HumanMattingEngine(
            device=self.device, human_matting_path=human_matting_path)

        # Load face box detector.
        assert os.path.exists(
            facebox_model_path), f'{facebox_model_path} does not exist!'
        self.detector = FaceBoxesDetector('FaceBoxes', facebox_model_path,
                                          True, self.device)

        self.detect_iris_landmarks_flag = detect_iris_landmarks
        if self.detect_iris_landmarks_flag:
            from fdlite import FaceDetection, FaceLandmark, IrisLandmark
            self.iris_detect_faces = FaceDetection()
            self.iris_detect_face_landmarks = FaceLandmark()
            self.iris_detect_iris_landmarks = IrisLandmark()

        end_time = time.time()
        torch.cuda.empty_cache()
        logger.info(f'Finished Loading Pre-trained Models. Time: '
                    f'{end_time - start_time:.2f}s')
    def _parse_args(self):
        parser = argparse.ArgumentParser(description='Evaluation script')
        parser.add_argument('--output_dir',
                            type=str,
                            help='Output directory',
                            default='output')
        parser.add_argument('--config_name',
                            type=str,
                            help='Configuration name',
                            default='alignment')
        # Use parse_known_args so unrelated CLI flags of the host program do
        # not crash the alignment-model setup.
        return parser.parse_known_args()[0]
    def preprocess(self, input_image_path):
        """Preprocess the input image for tracking."""
        if not os.path.exists(input_image_path):
            logger.warning(f'{input_image_path} does not exist!')
            return ERROR_CODE['FailedToDetect']

        start_time = time.time()
        logger.info('Starting Preprocessing...')
        name_list = []
        frame_index = 0

        # Bounding box detection (BGR -> RGB, then HWC -> CHW).
        # frame = torchvision.io.read_image(input_image_path)
        frame = cv2.imread(input_image_path)[:, :, ::-1].copy()
        frame = torch.Tensor(frame).permute(2, 0, 1).contiguous()[:3, ...]
        try:
            _, frame_bbox, _ = self.vgghead_encoder(frame, frame_index)
        except Exception:
            logger.error('Failed to detect face')
            return ERROR_CODE['FailedToDetect']
        if frame_bbox is None:
            logger.error('Failed to detect face')
            return ERROR_CODE['FailedToDetect']

        # Expand bounding box to a square crop region.
        name_list.append('00000.png')
        frame_bbox = expand_bbox(frame_bbox, scale=1.65).long()

        # Crop and resize to the fixed 1024x1024 working resolution.
        cropped_frame = torchvision.transforms.functional.crop(
            frame,
            top=frame_bbox[1],
            left=frame_bbox[0],
            height=frame_bbox[3] - frame_bbox[1],
            width=frame_bbox[2] - frame_bbox[0])
        cropped_frame = torchvision.transforms.functional.resize(
            cropped_frame, (1024, 1024), antialias=True)

        # Apply matting to composite the subject onto a white background.
        cropped_frame, mask = self.matting_engine(cropped_frame / 255.0,
                                                  return_type='matting',
                                                  background_rgb=1.0)
        cropped_frame = cropped_frame.cpu() * 255.0
        saved_image = np.round(cropped_frame.permute(
            1, 2, 0).numpy()).astype(np.uint8)[:, :, (2, 1, 0)]  # RGB -> BGR
        # Create output directories if they do not exist.
        self.sub_output_dir = os.path.join(
            self.output_preprocess,
            os.path.splitext(os.path.basename(input_image_path))[0])
        output_image_dir = os.path.join(self.sub_output_dir, 'images')
        output_mask_dir = os.path.join(self.sub_output_dir, 'mask')
        output_alpha_map_dir = os.path.join(self.sub_output_dir, 'alpha_maps')
        os.makedirs(output_image_dir, exist_ok=True)
        os.makedirs(output_mask_dir, exist_ok=True)
        os.makedirs(output_alpha_map_dir, exist_ok=True)

        # Save processed image, mask and (all-white) alpha map.
        cv2.imwrite(os.path.join(output_image_dir, name_list[frame_index]),
                    saved_image)
        cv2.imwrite(os.path.join(output_mask_dir, name_list[frame_index]),
                    np.array(mask.cpu() * 255.0).astype(np.uint8))
        cv2.imwrite(
            os.path.join(output_alpha_map_dir,
                         name_list[frame_index]).replace('.png', '.jpg'),
            (np.ones_like(saved_image) * 255).astype(np.uint8))

        # Landmark detection; if several faces are found, the landmarks of
        # the last detection win.
        detections, _ = self.detector.detect(saved_image, 0.8, 1)
        if len(detections) == 0:
            logger.error('Failed to detect face for landmarks')
            return ERROR_CODE['FailedToDetect']
        for detection in detections:
            x1_ori, y1_ori = detection[2], detection[3]
            x2_ori, y2_ori = x1_ori + detection[4], y1_ori + detection[5]
            scale = max(x2_ori - x1_ori, y2_ori - y1_ori) / 180
            center_w, center_h = (x1_ori + x2_ori) / 2, (y1_ori + y2_ori) / 2
            scale, center_w, center_h = (float(scale), float(center_w),
                                         float(center_h))
            face_landmarks = self.alignment.analyze(saved_image, scale,
                                                    center_w, center_h)

        # Normalize landmarks to [0, 1] (the crop is 1024x1024) and save.
        normalized_landmarks = np.zeros((face_landmarks.shape[0], 3))
        normalized_landmarks[:, :2] = face_landmarks / 1024
        landmark_output_dir = os.path.join(self.sub_output_dir, 'landmark2d')
        os.makedirs(landmark_output_dir, exist_ok=True)
        landmark_data = {
            'bounding_box': [],
            'face_landmark_2d': normalized_landmarks[None, ...],
        }
        landmark_path = os.path.join(landmark_output_dir, 'landmarks.npz')
        np.savez(landmark_path, **landmark_data)

        if self.detect_iris_landmarks_flag:
            self._detect_iris_landmarks(
                os.path.join(output_image_dir, name_list[frame_index]))

        end_time = time.time()
        torch.cuda.empty_cache()
        logger.info(
            f'Finished Processing Image. Time: {end_time - start_time:.2f}s')
        return 0
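
    # After a successful preprocess() call, output_dir contains (names fixed
    # by the code above; '<name>' is the input image's base name):
    #   preprocess/<name>/
    #   ├── images/00000.png       # matted 1024x1024 crop (white background)
    #   ├── mask/00000.png         # matting alpha as an 8-bit mask
    #   ├── alpha_maps/00000.jpg   # all-white placeholder
    #   └── landmark2d/
    #       ├── landmarks.npz      # normalized 2D landmarks
    #       └── iris.json          # only when detect_iris_landmarks=True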
    def optimize(self):
        """Optimize the tracking model using configuration data."""
        start_time = time.time()
        logger.info('Starting Optimization...')

        tyro.extras.set_accent_color('bright_yellow')
        # tyro builds the tracking config from the command line; defaults
        # apply when no flags are given.
        config_data = tyro.cli(BaseTrackingConfig)
        config_data.data.sequence = os.path.basename(self.sub_output_dir)
        config_data.data.root_folder = Path(
            os.path.dirname(self.sub_output_dir))

        if not os.path.exists(self.sub_output_dir):
            logger.error(f'Failed to load {self.sub_output_dir}')
            return ERROR_CODE['FailedToOptimize']

        config_data.exp.output_folder = Path(self.output_tracking)
        tracker = GlobalTracker(config_data)
        tracker.optimize()

        end_time = time.time()
        torch.cuda.empty_cache()
        logger.info(
            f'Finished Optimization. Time: {end_time - start_time:.2f}s')
        return 0
    def _detect_iris_landmarks(self, image_path):
        """Detect iris landmarks in the given image and save them as JSON."""
        from fdlite import face_detection_to_roi, iris_roi_from_face_landmarks

        img = Image.open(image_path)
        img_size = (1024, 1024)

        face_detections = self.iris_detect_faces(img)
        if len(face_detections) != 1:
            logger.warning('Empty iris landmarks')
        else:
            face_detection = face_detections[0]
            try:
                face_roi = face_detection_to_roi(face_detection, img_size)
            except ValueError:
                logger.warning('Empty iris landmarks')
                return

            face_landmarks = self.iris_detect_face_landmarks(img, face_roi)
            if len(face_landmarks) == 0:
                logger.warning('Empty iris landmarks')
                return

            iris_rois = iris_roi_from_face_landmarks(face_landmarks, img_size)
            if len(iris_rois) != 2:
                logger.warning('Empty iris landmarks')
                return

            landmarks = []
            for iris_roi in iris_rois[::-1]:
                try:
                    # Keep only the iris center landmark of each eye.
                    iris_landmarks = self.iris_detect_iris_landmarks(
                        img, iris_roi).iris[0:1]
                except np.linalg.LinAlgError:
                    logger.warning('Failed to get iris landmarks')
                    break
                # Append x and y coordinates scaled to the 1024px crop.
                for landmark in iris_landmarks:
                    landmarks.append(landmark.x * 1024)
                    landmarks.append(landmark.y * 1024)

            landmark_data = {'00000.png': landmarks}
            with open(
                    os.path.join(self.sub_output_dir, 'landmark2d',
                                 'iris.json'), 'w') as f:
                json.dump(landmark_data, f)
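
    # Resulting iris.json (values illustrative): one flat [x, y, x, y] list
    # per frame, i.e. both iris centers in 1024px crop coordinates:
    #   {"00000.png": [412.7, 501.3, 618.2, 498.9]}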
    def export(self):
        """Export the tracking results to the configured folder."""
        logger.info(f'Beginning export from {self.output_tracking}')
        start_time = time.time()

        if not os.path.exists(self.output_tracking):
            logger.error(f'Failed to load {self.output_tracking}')
            return ERROR_CODE['FailedToExport'], 'Failed'

        src_folder = Path(self.output_tracking)
        tgt_folder = Path(self.output_export,
                          os.path.basename(self.sub_output_dir))
        src_folder, config_data = load_config(src_folder)

        nerf_writer = NeRFDatasetWriter(config_data.data, tgt_folder, None,
                                        None, 'white')
        nerf_writer.write()

        flame_writer = TrackedFLAMEDatasetWriter(config_data.model,
                                                 src_folder,
                                                 tgt_folder,
                                                 mode='param',
                                                 epoch=-1)
        flame_writer.write()
        split_json(tgt_folder)

        end_time = time.time()
        torch.cuda.empty_cache()
        logger.info(f'Finished Export. Time: {end_time - start_time:.2f}s')
        return 0, str(tgt_folder)
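

if __name__ == '__main__':
    # Minimal usage sketch (hypothetical paths): the pretrained weights listed
    # in __init__ must be in place and a CUDA device available. Note that
    # optimize() reads tracking options from the command line via tyro.
    tracker = FlameTrackingSingleImage(output_dir='./output')
    if tracker.preprocess('./assets/portrait.png') == 0:
        if tracker.optimize() == 0:
            code, export_dir = tracker.export()
            if code == 0:
                print(f'Exported NeRF/FLAME dataset to {export_dir}')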