from PIL import Image, ImageSequence
import numpy as np
import torch
from moviepy.editor import VideoFileClip
import os
import imageio
import random
from diffusers.utils import export_to_video


def resize_and_center_crop(image, target_width, target_height):
    """Resize a numpy image so it covers the target size, then center-crop."""
    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    # Scale just enough that both dimensions cover the target.
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
    # Integer crop box guarantees the output is exactly target_width x target_height.
    left = (resized_width - target_width) // 2
    top = (resized_height - target_height) // 2
    cropped_image = resized_image.crop((left, top, left + target_width, top + target_height))
    return np.array(cropped_image)


def numpy2pytorch(imgs, device, dtype):
    # Map uint8 [0, 255] to approximately [-1, 1]; dividing by 127.0 (not 127.5)
    # makes the gray value 127 land exactly on 0, matching get_fg_video's fill.
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0
    h = h.movedim(-1, 1)  # NHWC -> NCHW
    return h.to(device=device, dtype=dtype)


def get_fg_video(video_list, mask_list, device, dtype):
    """Keep foreground pixels (mask == 255) and fill the background with gray."""
    video_np = np.stack(video_list, axis=0)
    mask_np = np.stack(mask_list, axis=0)
    mask_bool = mask_np == 255
    video_fg = np.where(mask_bool, video_np, 127)  # 127 normalizes to exactly 0 below
    h = torch.from_numpy(video_fg).float() / 127.0 - 1.0
    h = h.movedim(-1, 1)  # NHWC -> NCHW
    return h.to(device=device, dtype=dtype)


def pad(x, p, i):
    """Truncate list x to length i, or right-pad it with p up to length i."""
    return x[:i] if len(x) >= i else x + [p] * (i - len(x))


def gif_to_mp4(gif_path, mp4_path):
    clip = VideoFileClip(gif_path)
    clip.write_videofile(mp4_path)


def generate_light_sequence(light_tensor, num_frames=16, direction="r"):
    """Linearly interpolate from light_tensor to a 90-degree-rotated copy of it."""
    if direction == "l":
        target_tensor = torch.rot90(light_tensor, k=1, dims=(2, 3))
    elif direction == "r":
        target_tensor = torch.rot90(light_tensor, k=-1, dims=(2, 3))
    else:
        raise ValueError("direction must be either 'r' for right or 'l' for left")
    # Generate the sequence; num_frames must be >= 2.
    out_list = []
    for frame_idx in range(num_frames):
        t = frame_idx / (num_frames - 1)
        out_list.append((1 - t) * light_tensor + t * target_tensor)
    return torch.stack(out_list, dim=0).squeeze(1)


def tensor2vid(video: torch.Tensor, processor, output_type="np"):
    # video shape: [batch, channels, frames, height, width], e.g. [1, 4, 16, 512, 512]
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)  # -> [frames, channels, h, w]
        outputs.append(processor.postprocess(batch_vid, output_type))
    return outputs


def read_video(video_path: str, image_width, image_height):
    """Read a .gif or .mp4 into a list of resized, center-cropped PIL frames."""
    extension = video_path.split('.')[-1].lower()
    video_name = os.path.basename(video_path)
    video_list = []
    if extension == "gif":  # input from gif
        video = Image.open(video_path)
        for frame in ImageSequence.Iterator(video):
            frame = np.array(frame.convert("RGB"))
            video_list.append(resize_and_center_crop(frame, image_width, image_height))
    elif extension == "mp4":  # input from mp4
        reader = imageio.get_reader(video_path)
        for frame in reader:
            video_list.append(resize_and_center_crop(frame, image_width, image_height))
    else:
        raise ValueError("Unsupported input type; expected .gif or .mp4")
    video_list = [Image.fromarray(frame) for frame in video_list]
    return video_list, video_name


def read_mask(mask_folder: str):
    """Load every mask image in a folder, sorted by filename."""
    mask_list = []
    for mask_file in sorted(os.listdir(mask_folder)):
        mask_path = os.path.join(mask_folder, mask_file)
        mask_list.append(Image.open(mask_path).convert('RGB'))
    return mask_list
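

# A minimal usage sketch for the loaders above, kept as a comment so it does
# not run on import. The paths and the 512x512 size are hypothetical
# placeholders, and the masks are assumed to be binary (0/255) images aligned
# frame-for-frame with the video:
#
#   frames, name = read_video("input/example.mp4", 512, 512)
#   masks = read_mask("input/example_masks")
#   masks = [m.resize(frames[0].size) for m in masks]
#   fg = get_fg_video(
#       [np.array(f) for f in frames],
#       [np.array(m) for m in masks],
#       device="cuda", dtype=torch.float16,
#   )  # -> [frames, 3, 512, 512], background filled with neutral gray
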
def decode_latents(vae, latents, decode_chunk_size: int = 16):
    """Decode latents back to pixel space in chunks to bound peak memory."""
    latents = 1 / vae.config.scaling_factor * latents
    video = []
    for i in range(0, latents.shape[0], decode_chunk_size):
        batch_latents = latents[i : i + decode_chunk_size]
        video.append(vae.decode(batch_latents).sample)
    return torch.cat(video)


def encode_video(vae, video, decode_chunk_size: int = 16) -> torch.Tensor:
    """Encode video frames to latents in chunks, taking the distribution mode.

    Note: vae.config.scaling_factor is not applied here; decode_latents divides
    it out, so scaling is left to the caller.
    """
    latents = []
    for i in range(0, len(video), decode_chunk_size):
        batch_video = video[i : i + decode_chunk_size]
        latents.append(vae.encode(batch_video).latent_dist.mode())
    return torch.cat(latents)


def vis_video(input_video, video_processor, save_path):
    # input_video shape: [1, channels, frames, height, width]
    relight_video = video_processor.postprocess_video(video=input_video, output_type="pil")
    export_to_video(relight_video[0], save_path)


def set_all_seed(seed):
    """Seed every RNG used in this module for reproducible runs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
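

# Hedged end-to-end sketch, runnable as a script rather than on import. The
# checkpoint id, file path, and seed are illustrative assumptions, not values
# this module ships with. It shows the intended pairing of encode_video and
# decode_latents: the caller multiplies by vae.config.scaling_factor, and
# decode_latents divides it back out.
if __name__ == "__main__":
    from diffusers import AutoencoderKL
    from diffusers.video_processor import VideoProcessor

    set_all_seed(42)
    device, dtype = "cuda", torch.float16

    frames, name = read_video("input/example.gif", 512, 512)  # hypothetical path
    video = numpy2pytorch([np.array(f) for f in frames], device, dtype)

    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse", torch_dtype=dtype  # assumed checkpoint
    ).to(device)
    with torch.no_grad():
        latents = encode_video(vae, video) * vae.config.scaling_factor
        recon = decode_latents(vae, latents)  # [frames, 3, 512, 512] in ~[-1, 1]

    # Reshape to the [1, channels, frames, h, w] layout vis_video expects.
    vis_video(recon.movedim(0, 1).unsqueeze(0).float(), VideoProcessor(), "recon.mp4")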