# Light-A-Video/src/ic_light.py
import torch
import numpy as np
from enum import Enum
import math
import torch.nn.functional as F
from utils.tools import resize_and_center_crop, numpy2pytorch, pad, decode_latents, encode_video
class BGSource(Enum):
NONE = "None"
LEFT = "Left Light"
RIGHT = "Right Light"
TOP = "Top Light"
BOTTOM = "Bottom Light"
class Relighter:
    """Relight a video with an IC-Light style img2img pipeline, conditioned on a
    directional gradient background image and a text prompt."""
def __init__(self,
pipeline,
relight_prompt="",
num_frames=16,
image_width=512,
image_height=512,
num_samples=1,
steps=15,
cfg=2,
lowres_denoise=0.9,
bg_source=BGSource.RIGHT,
generator=None,
):
self.pipeline = pipeline
self.image_width = image_width
self.image_height = image_height
self.num_samples = num_samples
self.steps = steps
self.cfg = cfg
self.lowres_denoise = lowres_denoise
self.bg_source = bg_source
self.generator = generator
self.device = pipeline.device
self.num_frames = num_frames
self.vae = self.pipeline.vae
self.a_prompt = "best quality"
self.n_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
positive_prompt = relight_prompt + ', ' + self.a_prompt
negative_prompt = self.n_prompt
tokenizer = self.pipeline.tokenizer
device = self.pipeline.device
vae = self.vae
conds, unconds = self.encode_prompt_pair(tokenizer, device, positive_prompt, negative_prompt)
input_bg = self.create_background()
bg = resize_and_center_crop(input_bg, self.image_width, self.image_height)
bg_latent = numpy2pytorch([bg], device, vae.dtype)
bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
        self.bg_latent = bg_latent.repeat(self.num_frames, 1, 1, 1)  # fixed light source, repeated for every frame
self.conds = conds.repeat(self.num_frames, 1, 1)
self.unconds = unconds.repeat(self.num_frames, 1, 1)
    def encode_prompt_inner(self, tokenizer, txt):
        # Split long prompts into chunks of (model_max_length - 2) tokens, wrap each chunk
        # with BOS/EOS, pad to model_max_length, and encode all chunks in one batch.
        max_length = tokenizer.model_max_length
        chunk_length = tokenizer.model_max_length - 2
id_start = tokenizer.bos_token_id
id_end = tokenizer.eos_token_id
id_pad = id_end
tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
chunks = [pad(ck, id_pad, max_length) for ck in chunks]
token_ids = torch.tensor(chunks).to(device=self.device, dtype=torch.int64)
conds = self.pipeline.text_encoder(token_ids).last_hidden_state
return conds
def encode_prompt_pair(self, tokenizer, device, positive_prompt, negative_prompt):
c = self.encode_prompt_inner(tokenizer, positive_prompt)
uc = self.encode_prompt_inner(tokenizer, negative_prompt)
c_len = float(len(c))
uc_len = float(len(uc))
max_count = max(c_len, uc_len)
c_repeat = int(math.ceil(max_count / c_len))
uc_repeat = int(math.ceil(max_count / uc_len))
max_chunk = max(len(c), len(uc))
c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
c = torch.cat([p[None, ...] for p in c], dim=1)
uc = torch.cat([p[None, ...] for p in uc], dim=1)
return c.to(device), uc.to(device)
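    # Shape sketch for encode_prompt_pair above (illustrative, assuming an SD-1.5-style
    # text encoder with hidden size 768 and model_max_length 77): a 100-token positive
    # prompt yields 2 chunks of (77, 768) while a short negative prompt yields 1; the
    # shorter side is repeated to 2 chunks, and each side is then flattened along the
    # token axis into a single (1, 154, 768) prompt embedding.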
    def create_background(self):
        # Build a grayscale gradient image that acts as the directional light source.
        max_pix = 255
        min_pix = 0
        print(f"max light pix: {max_pix}, min light pix: {min_pix}")
if self.bg_source == BGSource.NONE:
return None
elif self.bg_source == BGSource.LEFT:
gradient = np.linspace(max_pix, min_pix, self.image_width)
image = np.tile(gradient, (self.image_height, 1))
return np.stack((image,) * 3, axis=-1).astype(np.uint8)
elif self.bg_source == BGSource.RIGHT:
gradient = np.linspace(min_pix, max_pix, self.image_width)
image = np.tile(gradient, (self.image_height, 1))
return np.stack((image,) * 3, axis=-1).astype(np.uint8)
elif self.bg_source == BGSource.TOP:
gradient = np.linspace(max_pix, min_pix, self.image_height)[:, None]
image = np.tile(gradient, (1, self.image_width))
return np.stack((image,) * 3, axis=-1).astype(np.uint8)
elif self.bg_source == BGSource.BOTTOM:
gradient = np.linspace(min_pix, max_pix, self.image_height)[:, None]
image = np.tile(gradient, (1, self.image_width))
return np.stack((image,) * 3, axis=-1).astype(np.uint8)
        else:
            raise ValueError(f'Unknown background source: {self.bg_source}')
@torch.no_grad()
def __call__(self, input_video, init_latent=None, input_strength=None):
        # Encode the source video into VAE latents; they are passed to the UNet as the
        # concatenated condition via cross_attention_kwargs below.
        input_latent = encode_video(self.vae, input_video) * self.vae.config.scaling_factor
        # Fall back to the configured strength only when the caller did not supply one.
        light_strength = input_strength if input_strength is not None else self.lowres_denoise
        # Fall back to the precomputed background latent when no initial latent is given.
        if init_latent is None:
            init_latent = self.bg_latent
latents = self.pipeline(
image=init_latent,
strength=light_strength,
prompt_embeds=self.conds,
negative_prompt_embeds=self.unconds,
width=self.image_width,
height=self.image_height,
            # img2img runs about `strength * num_inference_steps` denoising steps, so scale
            # the total up to keep roughly `self.steps` steps at the default strength.
            num_inference_steps=int(round(self.steps / self.lowres_denoise)),
num_images_per_prompt=self.num_samples,
generator=self.generator,
output_type='latent',
guidance_scale=self.cfg,
cross_attention_kwargs={'concat_conds': input_latent},
).images.to(self.pipeline.vae.dtype)
relight_video = decode_latents(self.vae, latents)
return relight_video
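
# Minimal usage sketch (illustrative, not part of this module): assumes `pipeline` is an
# IC-Light style diffusers img2img pipeline with its VAE, tokenizer and text encoder
# attached, and `frames` is a video in the format accepted by utils.tools.encode_video.
#
#   relighter = Relighter(
#       pipeline,
#       relight_prompt="warm sunset lighting",
#       num_frames=16,
#       bg_source=BGSource.LEFT,
#       generator=torch.Generator("cuda").manual_seed(0),
#   )
#   relit = relighter(frames)                               # default strength (lowres_denoise)
#   relit_strong = relighter(frames, input_strength=0.95)   # caller-supplied strength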