|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torchvision import transforms |
|
import cv2 |
|
from einops import rearrange |
|
import mediapipe as mp |
|
import torch |
|
import numpy as np |
|
from typing import Union |
|
from .affine_transform import AlignRestore, laplacianSmooth |
|
import face_alignment |
|
|
|
""" |
|
If you are enlarging the image, you should prefer to use INTER_LINEAR or INTER_CUBIC interpolation. If you are shrinking the image, you should prefer to use INTER_AREA interpolation. |
|
https://stackoverflow.com/questions/23853632/which-kind-of-interpolation-best-for-resizing-image |
|
""" |
|
|
|
|
|
def load_fixed_mask(resolution: int, mask_path: str = "latentsync/utils/mask.png") -> torch.Tensor:
    """Load the fixed face mask image and resize it to the target resolution.

    Args:
        resolution: target side length; the mask is resized to (resolution, resolution).
        mask_path: path of the mask image on disk (default matches the original
            hard-coded location, so existing callers are unaffected).

    Returns:
        Float tensor of shape (3, resolution, resolution) with values in [0, 1].

    Raises:
        FileNotFoundError: if the mask image cannot be read.
    """
    mask_image = cv2.imread(mask_path)
    if mask_image is None:
        # cv2.imread silently returns None for a missing/unreadable file;
        # fail loudly here instead of crashing later inside cvtColor.
        raise FileNotFoundError(f"Cannot read mask image: {mask_path}")
    mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
    # INTER_AREA is the preferred interpolation when shrinking an image.
    mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
    mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
    return mask_image
|
|
|
|
|
class ImageProcessor:
    """Face-image preprocessing: alignment, resizing, normalization and masking.

    Behavior depends on ``mask``:
      - "mouth" / "face" / "eye": a per-image region mask is built from
        MediaPipe FaceMesh landmarks (see preprocess_one_masked_image).
      - "half": the lower half of the image is masked (no landmark model).
      - "fix_mask": a fixed mask image is applied, optionally after affine
        face alignment; landmarks come from the face_alignment package when
        ``device`` is a GPU device, or from MediaPipe FaceMesh on CPU.
    """

    def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
        """
        Args:
            resolution: output side length; images and masks are resized to
                (resolution, resolution).
            mask: one of "fix_mask", "mouth", "face", "eye", "half".
            device: "cpu" selects the MediaPipe landmark fallback; any other
                value selects face_alignment on that device.
            mask_image: optional pre-loaded fixed mask tensor (C, H, W) in
                [0, 1]; loaded from disk via load_fixed_mask() when None.
        """
        self.resolution = resolution
        self.resize = transforms.Resize(
            (resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
        )
        # Maps [0, 1] pixel values to [-1, 1]; inplace to avoid an extra copy.
        self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
        self.mask = mask

        if mask in ["mouth", "face", "eye"]:
            self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True)
        if mask == "fix_mask":
            self.smoother = laplacianSmooth()
            self.restorer = AlignRestore()

            if mask_image is None:
                self.mask_image = load_fixed_mask(resolution)
            else:
                self.mask_image = mask_image

            if device != "cpu":
                self.fa = face_alignment.FaceAlignment(
                    face_alignment.LandmarksType.TWO_D, flip_input=False, device=device
                )
                self.face_mesh = None
            else:
                # BUGFIX: the CPU branch previously set BOTH self.fa and
                # self.face_mesh to None, so affine_transform (which falls
                # back to MediaPipe landmarks when self.fa is None) crashed
                # with an AttributeError on self.face_mesh.process(...).
                self.fa = None
                self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True)

    def detect_facial_landmarks(self, image: np.ndarray):
        """Run MediaPipe FaceMesh and return 478 (x, y) pixel coordinates.

        Args:
            image: HWC image array (FaceMesh expects RGB order).

        Returns:
            List of 478 (x, y) integer pixel-coordinate tuples.

        Raises:
            RuntimeError: if no face is detected in the image.
        """
        height, width, _ = image.shape
        results = self.face_mesh.process(image)
        if not results.multi_face_landmarks:
            raise RuntimeError("Face not detected")
        face_landmarks = results.multi_face_landmarks[0]
        # FaceMesh landmarks are normalized to [0, 1]; scale to pixel coords.
        landmark_coordinates = [
            (int(landmark.x * width), int(landmark.y * height)) for landmark in face_landmarks.landmark
        ]
        return landmark_coordinates

    def preprocess_one_masked_image(self, image: torch.Tensor) -> tuple:
        """Resize one image and build (pixel_values, masked_pixel_values, mask).

        The mask region is chosen by self.mask ("mouth"/"face" polygon,
        "half" lower half, "eye" everything below landmark 195). The returned
        mask is inverted so that 1 marks the masked-out region.

        NOTE(review): detect_facial_landmarks expects an HWC numpy array but
        receives the resized CHW tensor here — confirm the landmark-based
        mask types ("mouth"/"face"/"eye") against their actual callers.

        Raises:
            ValueError: if self.mask is not a supported mask type.
        """
        image = self.resize(image)

        if self.mask == "mouth" or self.mask == "face":
            landmark_coordinates = self.detect_facial_landmarks(image)
            if self.mask == "mouth":
                surround_landmarks = mouth_surround_landmarks
            else:
                surround_landmarks = face_surround_landmarks

            # Fill the surrounding polygon with 0 to mask out the region.
            points = [landmark_coordinates[landmark] for landmark in surround_landmarks]
            points = np.array(points)
            mask = np.ones((self.resolution, self.resolution))
            mask = cv2.fillPoly(mask, pts=[points], color=(0, 0, 0))
            mask = torch.from_numpy(mask)
            mask = mask.unsqueeze(0)
        elif self.mask == "half":
            mask = torch.ones((self.resolution, self.resolution))
            height = mask.shape[0]
            mask[height // 2 :, :] = 0
            mask = mask.unsqueeze(0)
        elif self.mask == "eye":
            mask = torch.ones((self.resolution, self.resolution))
            landmark_coordinates = self.detect_facial_landmarks(image)
            # Mask everything below landmark 195 (on the nose bridge).
            y = landmark_coordinates[195][1]
            mask[y:, :] = 0
            mask = mask.unsqueeze(0)
        else:
            raise ValueError("Invalid mask type")

        image = image.to(dtype=torch.float32)
        pixel_values = self.normalize(image / 255.0)
        masked_pixel_values = pixel_values * mask
        # Invert: downstream consumers expect 1 = masked region.
        mask = 1 - mask

        return pixel_values, masked_pixel_values, mask

    def affine_transform(self, image: torch.Tensor) -> tuple:
        """Align and crop the face via a 3-point affine warp.

        Landmarks come from face_alignment when self.fa is set, otherwise
        from the MediaPipe fallback. The three alignment points are the
        smoothed means of the two eyebrow groups and the nose group of the
        68-point layout.

        Returns:
            (face, box, affine_matrix): the aligned face as a CHW uint8-range
            tensor resized to self.resolution, the crop box [0, 0, w, h] of
            the warped face, and the affine matrix used for the warp.

        Raises:
            RuntimeError: if no face is detected (face_alignment path).
        """
        if self.fa is None:
            landmark_coordinates = np.array(self.detect_facial_landmarks(image))
            lm68 = mediapipe_lm478_to_face_alignment_lm68(landmark_coordinates)
        else:
            detected_faces = self.fa.get_landmarks(image)
            if detected_faces is None:
                raise RuntimeError("Face not detected")
            lm68 = detected_faces[0]

        # Temporal smoothing, then reduce 68 points to 3 stable anchors:
        # left brow, right brow, nose.
        points = self.smoother.smooth(lm68)
        lmk3_ = np.zeros((3, 2))
        lmk3_[0] = points[17:22].mean(0)
        lmk3_[1] = points[22:27].mean(0)
        lmk3_[2] = points[27:36].mean(0)

        face, affine_matrix = self.restorer.align_warp_face(
            image.copy(), lmks3=lmk3_, smooth=True, border_mode="constant"
        )
        box = [0, 0, face.shape[1], face.shape[0]]
        # INTER_CUBIC: the warped crop is typically upscaled to resolution.
        face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_CUBIC)
        face = rearrange(torch.from_numpy(face), "h w c -> c h w")
        return face, box, affine_matrix

    def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
        """Apply the fixed mask to one image (optionally after face alignment).

        Returns:
            (pixel_values, masked_pixel_values, mask) where mask is the first
            channel of the fixed mask image.
        """
        if affine_transform:
            image, _, _ = self.affine_transform(image)
        else:
            image = self.resize(image)
        pixel_values = self.normalize(image / 255.0)
        masked_pixel_values = pixel_values * self.mask_image
        return pixel_values, masked_pixel_values, self.mask_image[0:1]

    def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
        """Batch version: preprocess every image per the configured mask type.

        Args:
            images: batch of images, (B, H, W, C) or (B, C, H, W); channels-last
                input is converted to channels-first.
            affine_transform: only honored for the "fix_mask" type.

        Returns:
            Tuple of stacked (pixel_values, masked_pixel_values, masks) tensors.
        """
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        if images.shape[3] == 3:
            images = rearrange(images, "b h w c -> b c h w")
        if self.mask == "fix_mask":
            results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
        else:
            results = [self.preprocess_one_masked_image(image) for image in images]

        pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
        return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)

    def process_images(self, images: Union[torch.Tensor, np.ndarray]):
        """Resize and normalize a batch of images (no masking).

        Args:
            images: (B, H, W, C) or (B, C, H, W) batch; channels-last input is
                converted to channels-first.

        Returns:
            Normalized pixel tensor in [-1, 1], shape (B, C, resolution, resolution).
        """
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        if images.shape[3] == 3:
            images = rearrange(images, "b h w c -> b c h w")
        images = self.resize(images)
        pixel_values = self.normalize(images / 255.0)
        return pixel_values

    def close(self):
        """Release the MediaPipe FaceMesh resources, if any were created."""
        if self.face_mesh is not None:
            self.face_mesh.close()
|
|
|
|
|
def mediapipe_lm478_to_face_alignment_lm68(lm478, return_2d=True, landmark_indices=None):
    """Map MediaPipe 478-point landmarks to the 68-point face_alignment layout.

    Args:
        lm478: array-like of shape (478, 2) or (478, 3) with per-point
            coordinates. (Batched input is NOT supported — the previous
            docstring's "[B, 478, 3]" claim did not match the implementation;
            pass one face at a time.)
        return_2d: if True (default) keep only the (x, y) columns; if False,
            keep all coordinate columns provided (previously this parameter
            was ignored and the output was always 2-D).
        landmark_indices: optional sequence of point indices to select;
            defaults to the module-level ``landmark_points_68`` mapping.

    Returns:
        np.ndarray of shape (68, 2) — or (len(landmark_indices), C) for
        custom indices / return_2d=False.
    """
    if landmark_indices is None:
        landmark_indices = landmark_points_68
    lm478 = np.asarray(lm478)
    # Fancy-index all selected points in one vectorized step
    # (replaces the original per-point Python loop).
    selected = lm478[list(landmark_indices)]
    return selected[:, :2] if return_2d else selected
|
|
|
|
|
# Indices into MediaPipe's 478-point FaceMesh that approximate the classic
# 68-point face_alignment landmark layout (order follows the 68-point
# convention: jaw 0-16, brows 17-26, nose 27-35, eyes 36-47, mouth 48-67).
# NOTE(review): the mapping is approximate — FaceMesh has no exact
# one-to-one correspondence with the 68-point annotation scheme.
landmark_points_68 = [
    # jawline (17 points)
    162, 234, 93, 58, 172, 136, 149, 148, 152, 377, 378, 365, 397, 288, 323, 454, 389,
    # eyebrows (5 left + 5 right)
    71, 63, 105, 66, 107,
    336, 296, 334, 293, 301,
    # nose bridge (4) + nostril line (5)
    168, 197, 5, 4,
    75, 97, 2, 326, 305,
    # eyes (6 left + 6 right)
    33, 160, 158, 133, 153, 144,
    362, 385, 387, 263, 373, 380,
    # outer lips (12) + inner lips (8)
    61, 39, 37, 0, 267, 269, 291, 405, 314, 17, 84, 181,
    78, 82, 13, 312, 308, 317, 14, 87,
]
|
|
|
|
|
|
|
# FaceMesh indices tracing a polygon around the mouth region; used by
# ImageProcessor.preprocess_one_masked_image (mask="mouth") to zero out
# the mouth area via cv2.fillPoly.
mouth_surround_landmarks = [
    164, 165, 167, 92, 186, 57, 43, 106, 182, 83,
    18, 313, 406, 335, 273, 287, 410, 322, 391, 393,
]
|
|
|
# FaceMesh indices tracing a polygon around the lower face; used by
# ImageProcessor.preprocess_one_masked_image (mask="face") to zero out
# the face area via cv2.fillPoly.
face_surround_landmarks = [
    152, 377, 400, 378, 379, 365, 397, 288, 435, 433,
    411, 425, 423, 327, 326, 94, 97, 98, 203, 205,
    187, 213, 215, 58, 172, 136, 150, 149, 176, 148,
]
|
|
|
if __name__ == "__main__":
    # Smoke test: align the first frame of a sample video and save the crop.
    image_processor = ImageProcessor(512, mask="fix_mask")
    video = cv2.VideoCapture("/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/original/val/RD_Radio57_000.mp4")
    try:
        ret, frame = video.read()
        # BUGFIX: the return flag was previously ignored, so a failed read
        # (missing file, empty video) crashed later on a None frame inside
        # rearrange. The capture was also never released.
        if not ret:
            raise RuntimeError("Could not read a frame from the video")
    finally:
        video.release()

    # BGR HWC uint8 frame -> CHW tensor for affine_transform.
    frame = rearrange(torch.Tensor(frame).type(torch.uint8), "h w c -> c h w")
    face, _, _ = image_processor.affine_transform(frame)

    face = (rearrange(face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
    cv2.imwrite("face.jpg", face)
|
|
|
|
|
|
|
|