# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image


class VaceImageProcessor(object):

    def __init__(self, downsample=None, seq_len=None):
        self.downsample = downsample
        self.seq_len = seq_len

    def _pillow_convert(self, image, cvt_type='RGB'):
        if image.mode != cvt_type:
            if image.mode == 'P':
                image = image.convert(f'{cvt_type}A')
            if image.mode == f'{cvt_type}A':
                # composite the alpha channel onto a white background
                bg = Image.new(
                    cvt_type,
                    size=(image.width, image.height),
                    color=(255, 255, 255))
                bg.paste(image, (0, 0), mask=image)
                image = bg
            else:
                image = image.convert(cvt_type)
        return image

    def _load_image(self, img_path):
        if img_path is None or img_path == '':
            return None
        img = Image.open(img_path)
        img = self._pillow_convert(img)
        return img

    def _resize_crop(self, img, oh, ow, normalize=True):
        """
        Resize, center crop, convert to tensor, and normalize.
        """
        # resize and crop
        iw, ih = img.size
        if iw != ow or ih != oh:
            # resize
            scale = max(ow / iw, oh / ih)
            img = img.resize((round(scale * iw), round(scale * ih)),
                             resample=Image.Resampling.LANCZOS)
            assert img.width >= ow and img.height >= oh

            # center crop
            x1 = (img.width - ow) // 2
            y1 = (img.height - oh) // 2
            img = img.crop((x1, y1, x1 + ow, y1 + oh))

        # normalize
        if normalize:
            img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
        return img
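
    # Worked example (illustrative numbers, not from this module): for a
    # 640x480 source and a 480x480 target, scale = max(480/640, 480/480) = 1.0,
    # so the image keeps its size and only the center 480 columns survive the
    # crop; with normalize=True the result is a (3, 1, 480, 480) tensor in
    # [-1, 1].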

    def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
        return self._resize_crop(img, oh, ow, normalize)

    def load_image(self, data_key, **kwargs):
        return self.load_image_batch(data_key, **kwargs)

    def load_image_pair(self, data_key, data_key2, **kwargs):
        return self.load_image_batch(data_key, data_key2, **kwargs)

    def load_image_batch(self,
                         *data_key_batch,
                         normalize=True,
                         seq_len=None,
                         **kwargs):
        seq_len = self.seq_len if seq_len is None else seq_len
        imgs = []
        for data_key in data_key_batch:
            img = self._load_image(data_key)
            imgs.append(img)
        w, h = imgs[0].size
        dh, dw = self.downsample[1:]

        # compute output size so that (oh // dh) * (ow // dw) fits in seq_len
        scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
        oh = int(h * scale) // dh * dh
        ow = int(w * scale) // dw * dw
        assert (oh // dh) * (ow // dw) <= seq_len
        imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
        return *imgs, (oh, ow)
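

# Usage sketch (illustrative; the downsample factors, seq_len, and file path
# below are assumed values, not defaults shipped with this module):
#
#   proc = VaceImageProcessor(downsample=(4, 8, 8), seq_len=1024)
#   img, (oh, ow) = proc.load_image_batch('/path/to/image.png')
#   # img: float tensor of shape (3, 1, oh, ow) in [-1, 1]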


class VaceVideoProcessor(object):

    def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
                 zero_start, seq_len, keep_last, **kwargs):
        self.downsample = downsample
        self.min_area = min_area
        self.max_area = max_area
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.zero_start = zero_start
        self.keep_last = keep_last
        self.seq_len = seq_len
        assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])

    def set_area(self, area):
        self.min_area = area
        self.max_area = area

    def set_seq_len(self, seq_len):
        self.seq_len = seq_len

    @staticmethod
    def resize_crop(video: torch.Tensor, oh: int, ow: int):
        """
        Resize, center crop, and normalize a decord-loaded video.

        Parameters:
            video (torch.Tensor): tensor from `reader.get_batch(frame_ids)`,
                of shape (T, H, W, C)
            oh (int): target height
            ow (int): target width

        Returns:
            torch.Tensor: normalized tensor in [-1, 1], of shape (C, T, H, W)
        """
        # permute ([t, h, w, c] -> [t, c, h, w])
        video = video.permute(0, 3, 1, 2)

        # resize and crop
        ih, iw = video.shape[2:]
        if ih != oh or iw != ow:
            # resize
            scale = max(ow / iw, oh / ih)
            video = F.interpolate(
                video,
                size=(round(scale * ih), round(scale * iw)),
                mode='bicubic',
                antialias=True)
            assert video.size(3) >= ow and video.size(2) >= oh

            # center crop
            x1 = (video.size(3) - ow) // 2
            y1 = (video.size(2) - oh) // 2
            video = video[:, :, y1:y1 + oh, x1:x1 + ow]

        # permute ([t, c, h, w] -> [c, t, h, w]) and normalize
        video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
        return video
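
    # Example (illustrative): a (T, H, W, C) float batch such as
    # torch.zeros(5, 480, 832, 3) with oh=240, ow=416 comes back as a
    # (3, 5, 240, 416) tensor; uint8-range inputs end up in [-1, 1].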

    def _video_preprocess(self, video, oh, ow):
        return self.resize_crop(video, oh, ow)

    def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box,
                                  rng):
        target_fps = min(fps, self.max_fps)
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        ratio = h / w
        df, dh, dw = self.downsample

        area_z = min(self.seq_len, self.max_area / (dh * dw),
                     (h // dh) * (w // dw))
        of = min((int(duration * target_fps) - 1) // df + 1,
                 int(self.seq_len / area_z))

        # deduce target shape of the [latent video]
        target_area_z = min(area_z, int(self.seq_len / of))
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        of = (of - 1) * df + 1
        oh *= dh
        ow *= dw

        # sample frame ids
        target_duration = of / target_fps
        begin = 0. if self.zero_start else rng.uniform(
            0, duration - target_duration)
        timestamps = np.linspace(begin, begin + target_duration, of)
        frame_ids = np.argmax(
            np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
                           timestamps[:, None] < frame_timestamps[None, :, 1]),
            axis=1).tolist()
        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
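
    # Worked example for the shape math above (illustrative numbers, not from
    # this module): with downsample (df, dh, dw) = (4, 8, 8), seq_len = 32760,
    # and a 1280x720 source with max_area = 720 * 1280:
    #   area_z = min(32760, 921600 / 64, 90 * 160) = 14400 latent tokens/frame
    #   of     = min(frame budget, 32760 // 14400) = 2 latent frames
    #   oh, ow = round(sqrt(14400 * 0.5625)) = 90, 14400 // 90 = 160
    # so after upscaling by (df, dh, dw): 5 pixel frames at 720x1280, and the
    # latent sequence 2 * 14400 = 28800 stays within seq_len.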

    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
                                      crop_box, rng):
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        ratio = h / w
        df, dh, dw = self.downsample

        area_z = min(self.seq_len, self.max_area / (dh * dw),
                     (h // dh) * (w // dw))
        of = min((len(frame_timestamps) - 1) // df + 1,
                 int(self.seq_len / area_z))

        # deduce target shape of the [latent video]
        target_area_z = min(area_z, int(self.seq_len / of))
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        of = (of - 1) * df + 1
        oh *= dh
        ow *= dw

        # sample frame ids spanning the full clip so the last frame is kept
        target_duration = duration
        target_fps = of / target_duration
        timestamps = np.linspace(0., target_duration, of)
        frame_ids = np.argmax(
            np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
                           timestamps[:, None] <= frame_timestamps[None, :, 1]),
            axis=1).tolist()
        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

    def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
        if self.keep_last:
            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h,
                                                      w, crop_box, rng)
        else:
            return self._get_frameid_bbox_default(fps, frame_timestamps, h, w,
                                                  crop_box, rng)

    def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
        return self.load_video_batch(
            data_key, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_pair(self,
                        data_key,
                        data_key2,
                        crop_box=None,
                        seed=2024,
                        **kwargs):
        return self.load_video_batch(
            data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_batch(self,
                         *data_key_batch,
                         crop_box=None,
                         seed=2024,
                         **kwargs):
        rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)

        # read video
        import decord
        decord.bridge.set_bridge('torch')
        readers = []
        for data_k in data_key_batch:
            reader = decord.VideoReader(data_k)
            readers.append(reader)

        fps = readers[0].get_avg_fps()
        length = min([len(r) for r in readers])
        frame_timestamps = [
            readers[0].get_frame_timestamp(i) for i in range(length)
        ]
        frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
        h, w = readers[0].next().shape[:2]
        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(
            fps, frame_timestamps, h, w, crop_box, rng)

        # preprocess video
        videos = [
            reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :]
            for reader in readers
        ]
        videos = [self._video_preprocess(video, oh, ow) for video in videos]
        return *videos, frame_ids, (oh, ow), fps
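

# Usage sketch (illustrative; the constructor values mirror a typical 480p
# setup but are assumed here, not defaults shipped with this module):
#
#   proc = VaceVideoProcessor(
#       downsample=(4, 8, 8), min_area=480 * 832, max_area=480 * 832,
#       min_fps=16, max_fps=16, zero_start=True, seq_len=32760,
#       keep_last=True)
#   video, frame_ids, (oh, ow), fps = proc.load_video('/path/to/video.mp4')
#   # video: float tensor of shape (3, len(frame_ids), oh, ow) in [-1, 1]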


def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size,
                   device):
    # fill missing video/mask slots with neutral tensors
    for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
        if sub_src_video is None and sub_src_mask is None:
            src_video[i] = torch.zeros(
                (3, num_frames, image_size[0], image_size[1]), device=device)
            src_mask[i] = torch.ones(
                (1, num_frames, image_size[0], image_size[1]), device=device)

    # letterbox each reference image onto a white canvas of the target size
    for i, ref_images in enumerate(src_ref_images):
        if ref_images is not None:
            for j, ref_img in enumerate(ref_images):
                if ref_img is not None and ref_img.shape[-2:] != image_size:
                    canvas_height, canvas_width = image_size
                    ref_height, ref_width = ref_img.shape[-2:]
                    white_canvas = torch.ones(
                        (3, 1, canvas_height, canvas_width),
                        device=device)  # [-1, 1]
                    scale = min(canvas_height / ref_height,
                                canvas_width / ref_width)
                    new_height = int(ref_height * scale)
                    new_width = int(ref_width * scale)
                    resized_image = F.interpolate(
                        ref_img.squeeze(1).unsqueeze(0),
                        size=(new_height, new_width),
                        mode='bilinear',
                        align_corners=False).squeeze(0).unsqueeze(1)
                    top = (canvas_height - new_height) // 2
                    left = (canvas_width - new_width) // 2
                    white_canvas[:, :, top:top + new_height,
                                 left:left + new_width] = resized_image
                    src_ref_images[i][j] = white_canvas
    return src_video, src_mask, src_ref_images
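

if __name__ == '__main__':
    # Minimal smoke test (illustrative; every tensor size below is a made-up
    # example, not a value taken from this module): fill one missing
    # video/mask slot and letterbox one undersized reference image.
    device = torch.device('cpu')
    num_frames, image_size = 9, (480, 832)
    videos, masks = [None], [None]
    refs = [[torch.zeros(3, 1, 240, 240)]]
    videos, masks, refs = prepare_source(videos, masks, refs, num_frames,
                                         image_size, device)
    # expected: (3, 9, 480, 832) (1, 9, 480, 832) (3, 1, 480, 832)
    print(videos[0].shape, masks[0].shape, refs[0][0].shape)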