import math
from enum import Enum

import numpy as np
import torch
import torch.nn.functional as F

from utils.tools import resize_and_center_crop, numpy2pytorch, pad, decode_latents, encode_video


class BGSource(Enum):
    NONE = "None"
    LEFT = "Left Light"
    RIGHT = "Right Light"
    TOP = "Top Light"
    BOTTOM = "Bottom Light"


class Relighter:
    def __init__(self,
                 pipeline,
                 relight_prompt="",
                 num_frames=16,
                 image_width=512,
                 image_height=512,
                 num_samples=1,
                 steps=15,
                 cfg=2,
                 lowres_denoise=0.9,
                 bg_source=BGSource.RIGHT,
                 generator=None,
                 ):
        self.pipeline = pipeline
        self.image_width = image_width
        self.image_height = image_height
        self.num_samples = num_samples
        self.steps = steps
        self.cfg = cfg
        self.lowres_denoise = lowres_denoise
        self.bg_source = bg_source
        self.generator = generator
        self.device = pipeline.device
        self.num_frames = num_frames
        self.vae = self.pipeline.vae

        self.a_prompt = "best quality"
        self.n_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

        positive_prompt = relight_prompt + ', ' + self.a_prompt
        negative_prompt = self.n_prompt

        tokenizer = self.pipeline.tokenizer
        device = self.pipeline.device
        vae = self.vae

        # Pre-compute the prompt embeddings and the lighting-background latent once;
        # they are reused for every frame of the video.
        conds, unconds = self.encode_prompt_pair(tokenizer, device, positive_prompt, negative_prompt)

        input_bg = self.create_background()
        bg = resize_and_center_crop(input_bg, self.image_width, self.image_height)
        bg_latent = numpy2pytorch([bg], device, vae.dtype)
        bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor

        self.bg_latent = bg_latent.repeat(self.num_frames, 1, 1, 1)  # fixed light source shared by all frames
        self.conds = conds.repeat(self.num_frames, 1, 1)
        self.unconds = unconds.repeat(self.num_frames, 1, 1)

    def encode_prompt_inner(self, tokenizer, txt):
        # Encode an arbitrarily long prompt: split the token stream into chunks that fit
        # the text encoder, wrap each chunk with BOS/EOS, and pad to the model max length.
        max_length = tokenizer.model_max_length
        chunk_length = tokenizer.model_max_length - 2
        id_start = tokenizer.bos_token_id
        id_end = tokenizer.eos_token_id
        id_pad = id_end

        tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
        chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
        chunks = [pad(ck, id_pad, max_length) for ck in chunks]

        token_ids = torch.tensor(chunks).to(device=self.device, dtype=torch.int64)
        conds = self.pipeline.text_encoder(token_ids).last_hidden_state

        return conds

    def encode_prompt_pair(self, tokenizer, device, positive_prompt, negative_prompt):
        # Encode positive and negative prompts, repeating the shorter embedding so both
        # cover the same number of chunks, then flatten the chunks along the token axis.
        c = self.encode_prompt_inner(tokenizer, positive_prompt)
        uc = self.encode_prompt_inner(tokenizer, negative_prompt)

        c_len = float(len(c))
        uc_len = float(len(uc))
        max_count = max(c_len, uc_len)

        c_repeat = int(math.ceil(max_count / c_len))
        uc_repeat = int(math.ceil(max_count / uc_len))
        max_chunk = max(len(c), len(uc))

        c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
        uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]

        c = torch.cat([p[None, ...] for p in c], dim=1)
        uc = torch.cat([p[None, ...] for p in uc], dim=1)

        return c.to(device), uc.to(device)

    def create_background(self):
        # Build a grayscale gradient image that encodes the light direction.
        max_pix = 255
        min_pix = 0
        print(f"max light pix: {max_pix}, min light pix: {min_pix}")

        if self.bg_source == BGSource.NONE:
            return None
        elif self.bg_source == BGSource.LEFT:
            gradient = np.linspace(max_pix, min_pix, self.image_width)
            image = np.tile(gradient, (self.image_height, 1))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        elif self.bg_source == BGSource.RIGHT:
            gradient = np.linspace(min_pix, max_pix, self.image_width)
            image = np.tile(gradient, (self.image_height, 1))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        elif self.bg_source == BGSource.TOP:
            gradient = np.linspace(max_pix, min_pix, self.image_height)[:, None]
            image = np.tile(gradient, (1, self.image_width))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        elif self.bg_source == BGSource.BOTTOM:
            gradient = np.linspace(min_pix, max_pix, self.image_height)[:, None]
            image = np.tile(gradient, (1, self.image_width))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        else:
            raise ValueError('Unknown background source!')

    @torch.no_grad()
    def __call__(self, input_video, init_latent=None, input_strength=None):
        # Encode the input video into latents used as concat conditions, then run the
        # img2img pipeline starting from the background (light) latent.
        input_latent = encode_video(self.vae, input_video) * self.vae.config.scaling_factor

        light_strength = input_strength if input_strength is not None else self.lowres_denoise

        if init_latent is None:
            init_latent = self.bg_latent

        latents = self.pipeline(
            image=init_latent,
            strength=light_strength,
            prompt_embeds=self.conds,
            negative_prompt_embeds=self.unconds,
            width=self.image_width,
            height=self.image_height,
            num_inference_steps=int(round(self.steps / self.lowres_denoise)),
            num_images_per_prompt=self.num_samples,
            generator=self.generator,
            output_type='latent',
            guidance_scale=self.cfg,
            cross_attention_kwargs={'concat_conds': input_latent},
        ).images.to(self.pipeline.vae.dtype)

        relight_video = decode_latents(self.vae, latents)
        return relight_video
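

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It assumes
# a diffusers-style img2img pipeline object exposing `device`, `vae`,
# `tokenizer`, `text_encoder`, and a `concat_conds` cross-attention hook, plus
# an `input_video` tensor in the format accepted by `encode_video`.
# `load_relight_pipeline` is a hypothetical loader, not defined in this repo.
#
#   pipeline = load_relight_pipeline()
#   relighter = Relighter(
#       pipeline,
#       relight_prompt="warm sunset lighting",
#       num_frames=16,
#       bg_source=BGSource.LEFT,
#       generator=torch.Generator(device=pipeline.device).manual_seed(0),
#   )
#   relit_video = relighter(input_video)  # decoded, relit frames
# ---------------------------------------------------------------------------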