# Video relighting utilities: wraps a diffusion pipeline to relight frames
# against a synthetic gradient-light background.
import math
from enum import Enum

import numpy as np
import torch
import torch.nn.functional as F

from utils.tools import resize_and_center_crop, numpy2pytorch, pad, decode_latents, encode_video
class BGSource(Enum):
    """Direction of the synthetic gradient background light source.

    The string values are user-facing labels (e.g. for a UI dropdown) and
    are matched against in ``Relighter.create_background``.
    """
    NONE = "None"
    LEFT = "Left Light"
    RIGHT = "Right Light"
    TOP = "Top Light"
    BOTTOM = "Bottom Light"
class Relighter:
    """Relight a video with a fixed synthetic light source via a diffusion pipeline.

    At construction time this precomputes (and caches per frame):
      * the positive/negative prompt embeddings (``self.conds`` / ``self.unconds``),
      * the VAE latent of a gradient "light" background image (``self.bg_latent``).

    Calling the instance then runs an img2img-style denoise of the input video
    conditioned on those cached tensors.
    """

    def __init__(self,
                 pipeline,
                 relight_prompt="",
                 num_frames=16,
                 image_width=512,
                 image_height=512,
                 num_samples=1,
                 steps=15,
                 cfg=2,
                 lowres_denoise=0.9,
                 bg_source=BGSource.RIGHT,
                 generator=None,
                 ):
        """Cache prompt embeddings and the background-light latent.

        Args:
            pipeline: a diffusers-style pipeline exposing ``tokenizer``,
                ``text_encoder``, ``vae``, ``device`` and an img2img call.
            relight_prompt: text describing the desired lighting; the quality
                booster ``self.a_prompt`` is appended to it.
            num_frames: number of video frames; cached tensors are repeated
                to this batch size.
            image_width / image_height: working resolution in pixels.
            num_samples: images per prompt passed through to the pipeline.
            steps: base denoising step count (rescaled by ``lowres_denoise``
                at call time).
            cfg: classifier-free guidance scale.
            lowres_denoise: default img2img strength in (0, 1].
            bg_source: which ``BGSource`` gradient to use as the light map.
            generator: optional ``torch.Generator`` for reproducibility.
        """
        self.pipeline = pipeline
        self.image_width = image_width
        self.image_height = image_height
        self.num_samples = num_samples
        self.steps = steps
        self.cfg = cfg
        self.lowres_denoise = lowres_denoise
        self.bg_source = bg_source
        self.generator = generator
        self.device = pipeline.device
        self.num_frames = num_frames
        self.vae = self.pipeline.vae
        # Quality boosters appended to every prompt pair.
        self.a_prompt = "best quality"
        self.n_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
        positive_prompt = relight_prompt + ', ' + self.a_prompt
        negative_prompt = self.n_prompt
        tokenizer = self.pipeline.tokenizer
        device = self.pipeline.device
        vae = self.vae
        conds, unconds = self.encode_prompt_pair(tokenizer, device, positive_prompt, negative_prompt)
        # NOTE(review): create_background() returns None for BGSource.NONE, which
        # would make resize_and_center_crop fail — presumably NONE is not meant
        # to be used with this class; confirm with callers.
        input_bg = self.create_background()
        bg = resize_and_center_crop(input_bg, self.image_width, self.image_height)
        bg_latent = numpy2pytorch([bg], device, vae.dtype)
        bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
        # Fixed light source: the same background latent / embeddings are
        # repeated for every frame of the video.
        self.bg_latent = bg_latent.repeat(self.num_frames, 1, 1, 1)
        self.conds = conds.repeat(self.num_frames, 1, 1)
        self.unconds = unconds.repeat(self.num_frames, 1, 1)

    def encode_prompt_inner(self, tokenizer, txt):
        """Encode ``txt`` into CLIP hidden states, chunking long prompts.

        Prompts longer than the tokenizer's max length are split into
        BOS/EOS-wrapped chunks of ``model_max_length - 2`` tokens, each padded
        (with the EOS id) to ``model_max_length``, and encoded independently.

        Returns:
            Tensor of shape (num_chunks, model_max_length, hidden_dim).
        """
        max_length = tokenizer.model_max_length
        chunk_length = tokenizer.model_max_length - 2
        id_start = tokenizer.bos_token_id
        id_end = tokenizer.eos_token_id
        id_pad = id_end  # CLIP convention: pad with the EOS token id
        tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
        chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
        chunks = [pad(ck, id_pad, max_length) for ck in chunks]
        token_ids = torch.tensor(chunks).to(device=self.device, dtype=torch.int64)
        conds = self.pipeline.text_encoder(token_ids).last_hidden_state
        return conds

    def encode_prompt_pair(self, tokenizer, device, positive_prompt, negative_prompt):
        """Encode a positive/negative prompt pair to equal chunk counts.

        The shorter encoding is tiled so both sides have the same number of
        chunks, then each is flattened along the token axis to shape
        (1, num_chunks * model_max_length, hidden_dim).

        Returns:
            (conds, unconds) tensors on ``device``.
        """
        c = self.encode_prompt_inner(tokenizer, positive_prompt)
        uc = self.encode_prompt_inner(tokenizer, negative_prompt)
        c_len = float(len(c))
        uc_len = float(len(uc))
        max_count = max(c_len, uc_len)
        # Repeat the shorter side enough times to match the longer one,
        # then truncate both to exactly max_chunk chunks.
        c_repeat = int(math.ceil(max_count / c_len))
        uc_repeat = int(math.ceil(max_count / uc_len))
        max_chunk = max(len(c), len(uc))
        c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
        uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
        # Concatenate chunks along the token dimension into one long sequence.
        c = torch.cat([p[None, ...] for p in c], dim=1)
        uc = torch.cat([p[None, ...] for p in uc], dim=1)
        return c.to(device), uc.to(device)

    def create_background(self):
        """Build a grayscale gradient image encoding the light direction.

        Returns:
            uint8 array of shape (image_height, image_width, 3) with a
            255→0 linear gradient oriented by ``self.bg_source``, or ``None``
            for ``BGSource.NONE``.

        Raises:
            ValueError: if ``self.bg_source`` is not a known ``BGSource``.
        """
        max_pix = 255
        min_pix = 0
        print(f"max light pix:{max_pix}, min light pix:{min_pix}")
        if self.bg_source == BGSource.NONE:
            return None
        elif self.bg_source == BGSource.LEFT:
            # Bright on the left, dark on the right.
            gradient = np.linspace(max_pix, min_pix, self.image_width)
            image = np.tile(gradient, (self.image_height, 1))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        elif self.bg_source == BGSource.RIGHT:
            gradient = np.linspace(min_pix, max_pix, self.image_width)
            image = np.tile(gradient, (self.image_height, 1))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        elif self.bg_source == BGSource.TOP:
            # Column gradient, tiled across the width.
            gradient = np.linspace(max_pix, min_pix, self.image_height)[:, None]
            image = np.tile(gradient, (1, self.image_width))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        elif self.bg_source == BGSource.BOTTOM:
            gradient = np.linspace(min_pix, max_pix, self.image_height)[:, None]
            image = np.tile(gradient, (1, self.image_width))
            return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        else:
            raise ValueError('Wrong initial latent!')

    def __call__(self, input_video, init_latent=None, input_strength=None):
        """Relight ``input_video`` and return the decoded video frames.

        Args:
            input_video: video frames accepted by ``encode_video`` (project
                helper — exact format not visible here).
            init_latent: optional starting latent; defaults to the cached
                background-light latent.
            input_strength: optional denoise strength overriding
                ``self.lowres_denoise``.
        """
        input_latent = encode_video(self.vae, input_video) * self.vae.config.scaling_factor
        # Explicit None checks: a bare `if input_strength:` would silently
        # discard an explicit strength of 0, and truthiness on a multi-element
        # tensor (`if not init_latent:`) raises a RuntimeError in PyTorch.
        if input_strength is not None:
            light_strength = input_strength
        else:
            light_strength = self.lowres_denoise
        if init_latent is None:
            init_latent = self.bg_latent
        latents = self.pipeline(
            image=init_latent,
            strength=light_strength,
            prompt_embeds=self.conds,
            negative_prompt_embeds=self.unconds,
            width=self.image_width,
            height=self.image_height,
            # NOTE(review): step count is rescaled by the *default* denoise
            # strength even when input_strength overrides it — confirm intended.
            num_inference_steps=int(round(self.steps / self.lowres_denoise)),
            num_images_per_prompt=self.num_samples,
            generator=self.generator,
            output_type='latent',
            guidance_scale=self.cfg,
            # The input-video latent is injected through cross-attention kwargs
            # (IC-Light-style concat conditioning).
            cross_attention_kwargs={'concat_conds': input_latent},
        ).images.to(self.pipeline.vae.dtype)
        relight_video = decode_latents(self.vae, latents)
        return relight_video