import os

import numpy as np
from typing import List, Union
import PIL
import torch
import torch.utils.data
import torch.utils.checkpoint
import cv2
from einops import rearrange
from tqdm.auto import tqdm
from diffusers.pipeline_utils import DiffusionPipeline

from video_diffusion.common.image_util import make_grid, annotate_image
from video_diffusion.common.image_util import save_gif_mp4_folder_type
class SampleLogger:
    def __init__(
        self,
        editing_prompts: List[str],
        clip_length: int,
        logdir: str,
        subdir: str = "sample",
        num_samples_per_prompt: int = 1,
        sample_seeds: List[int] = None,
        num_inference_steps: int = 20,
        guidance_scale: float = 7,
        strength: float = None,
        annotate: bool = False,
        annotate_size: int = 15,
        make_grid: bool = True,
        grid_column_size: int = 2,
        layout_mask_dir: str = None,  # directory holding one sub-folder of masks per layout
        layouts_masks_orders: List[List[str]] = None,  # per-prompt ordering of layout sub-folder names
        stride: int = 1,
        n_sample_frame: int = 8,
        start_sample_frame: int = None,
        sampling_rate: int = 1,
        **args
    ) -> None:
        self.editing_prompts = editing_prompts
        self.clip_length = clip_length
        self.guidance_scale = guidance_scale
        self.num_inference_steps = num_inference_steps
        self.strength = strength
        if sample_seeds is None:
            max_num_samples_per_prompt = int(1e5)
            if num_samples_per_prompt > max_num_samples_per_prompt:
                raise ValueError(
                    f"num_samples_per_prompt ({num_samples_per_prompt}) exceeds {max_num_samples_per_prompt}"
                )
            sample_seeds = torch.randint(0, max_num_samples_per_prompt, (num_samples_per_prompt,))
            sample_seeds = sorted(sample_seeds.numpy().tolist())
        self.sample_seeds = sample_seeds

        self.logdir = os.path.join(logdir, subdir)
        os.makedirs(self.logdir, exist_ok=True)

        self.annotate = annotate
        self.annotate_size = annotate_size
        self.make_grid = make_grid
        self.grid_column_size = grid_column_size
        self.layout_mask_dir = layout_mask_dir
        self.layout_mask_orders = layouts_masks_orders
        self.stride = stride
        self.n_sample_frame = n_sample_frame
        self.start_sample_frame = start_sample_frame
        self.sampling_rate = sampling_rate
    def _read_mask(self, mask_path, index: int, dest_size=(64, 64)):
        # Read the binary layout mask for a single frame; files are expected
        # to be named 00000.png, 00001.png, ... inside mask_path.
        mask_path = os.path.join(mask_path, f"{index:05d}.png")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = (mask > 0).astype(np.uint8)
        mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST)
        mask = mask[np.newaxis, ...]  # (1, H, W)
        return mask
    def get_frame_indices(self, index):
        if self.start_sample_frame is not None:
            frame_start = self.start_sample_frame + self.stride * index
        else:
            frame_start = self.stride * index
        return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame))
    def read_layout_and_merge_masks(self, index):
        layouts_all, masks_all = [], []
        for idx, layout_mask_order_per in enumerate(self.layout_mask_orders):
            layout_ = []
            for layout_name in layout_mask_order_per:  # loop over the layout sub-folders for this prompt
                frame_indices = self.get_frame_indices(index % self.clip_length)
                layout_mask_dir = os.path.join(self.layout_mask_dir, layout_name)
                mask = [self._read_mask(layout_mask_dir, i) for i in frame_indices]
                masks = np.stack(mask)  # (f, 1, h, w)
                layout_.append(masks)
            layout_ = np.stack(layout_)  # (s, f, 1, h, w)

            # Union of all layout masks, computed per frame.
            merged_masks = []
            for i in range(int(self.n_sample_frame)):
                merged_mask_frame = np.sum(layout_[:, i, :, :], axis=0)
                merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8)
                merged_masks.append(merged_mask_frame)

            masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w")
            masks = torch.from_numpy(masks).half()
            layouts = rearrange(layout_, "s f c h w -> f s c h w")
            layouts = torch.from_numpy(layouts).half()

            layouts_all.append(layouts)
            masks_all.append(masks)  # append the merged mask tensor, not the raw per-frame list
        return masks_all, layouts_all
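    # Shape reference for the tensors built above (assuming the default
    # dest_size=(64, 64) in _read_mask):
    #   masks_all[i]   -> (1, n_sample_frame, 64, 64)               half-precision union of all layouts
    #   layouts_all[i] -> (n_sample_frame, num_layouts, 1, 64, 64)  half-precision per-layout masks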
    def log_sample_images(
        self, pipeline: DiffusionPipeline,
        device: torch.device, step: int,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        masks: Union[torch.FloatTensor, PIL.Image.Image] = None,
        layouts: Union[torch.FloatTensor, PIL.Image.Image] = None,
        latents: torch.FloatTensor = None,
        control: torch.FloatTensor = None,
        controlnet_conditioning_scale=None,
        negative_prompt: Union[str, List[str]] = None,
        blending_percentage=None,
        trajs=None,
        flatten_res=None,
        source_prompt=None,
        inject_step=None,
        old_qk=None,
        use_pnp=None,
        cluster_inversion_feature=None,
        vis_cross_attn=None,
        attn_inversion_dict=None,
    ):
        torch.cuda.empty_cache()
        samples_all = []
        attention_all = []
        # handle input image
        if image is not None:
            input_pil_images = pipeline.numpy_to_pil(tensor_to_numpy(image))[0]
            samples_all.append(input_pil_images)
            # samples_all.append([
            #     annotate_image(image, "input sequence", font_size=self.annotate_size) for image in input_pil_images
            # ])

        # masks_all, layouts_all = self.read_layout_and_merge_masks()
        # for idx, (prompt, masks, layouts) in enumerate(tqdm(zip(self.editing_prompts, masks_all, layouts_all), desc="Generating sample images")):
        for idx, prompt in enumerate(tqdm(self.editing_prompts, desc="Generating sample images")):
            for seed in self.sample_seeds:
                generator = torch.Generator(device=device)
                generator.manual_seed(seed)
                sequence_return = pipeline(
                    prompt=prompt,
                    image=image,  # torch.Size([8, 3, 512, 512])
                    latent_mask=masks,
                    layouts=layouts,
                    strength=self.strength,
                    generator=generator,
                    num_inference_steps=self.num_inference_steps,
                    clip_length=self.clip_length,
                    guidance_scale=self.guidance_scale,
                    num_images_per_prompt=1,
                    # used in null inversion
                    control=control,
                    controlnet_conditioning_scale=controlnet_conditioning_scale,
                    latents=latents,
                    # uncond_embeddings_list=uncond_embeddings_list,
                    blending_percentage=blending_percentage,
                    logdir=self.logdir,
                    trajs=trajs,
                    flatten_res=flatten_res,
                    negative_prompt=negative_prompt,
                    source_prompt=source_prompt,
                    inject_step=inject_step,
                    old_qk=old_qk,
                    use_pnp=use_pnp,
                    cluster_inversion_feature=cluster_inversion_feature,
                    vis_cross_attn=vis_cross_attn,
                    attn_inversion_dict=attn_inversion_dict,
                )
                sequence = sequence_return.images[0]
                torch.cuda.empty_cache()

                if self.annotate:
                    images = [
                        annotate_image(image, prompt, font_size=self.annotate_size) for image in sequence
                    ]
                else:
                    images = sequence

                if self.make_grid:
                    samples_all.append(images)
                save_path = os.path.join(self.logdir, f"step_{step}_{idx}_{seed}.gif")
                save_gif_mp4_folder_type(images, save_path)

        if self.make_grid:
            samples_all = [make_grid(images, cols=int(np.ceil(np.sqrt(len(samples_all))))) for images in zip(*samples_all)]
            save_path = os.path.join(self.logdir, f"step_{step}.gif")
            save_gif_mp4_folder_type(samples_all, save_path)
        return samples_all
def tensor_to_numpy(image, b=1):
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().float().numpy()
    image = rearrange(image, "(b f) c h w -> b f h w c", b=b)
    return image
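

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The prompt, the "data/layout_masks"
# directory, and its "foreground"/"background" sub-folders are hypothetical
# placeholders; the actual sampling call is omitted because it requires a
# loaded DiffusionPipeline checkpoint and source-video latents.
if __name__ == "__main__":
    demo_mask_dir = "data/layout_masks"  # hypothetical path with one sub-folder per layout
    logger = SampleLogger(
        editing_prompts=["a silver jeep driving down a curvy road"],  # hypothetical prompt
        clip_length=8,
        logdir="outputs/demo",
        layout_mask_dir=demo_mask_dir,
        layouts_masks_orders=[["foreground", "background"]],  # hypothetical layout names
    )
    if os.path.isdir(demo_mask_dir):
        # Expects 00000.png ... 00007.png inside each layout sub-folder.
        masks_all, layouts_all = logger.read_layout_and_merge_masks(index=0)
        print(masks_all[0].shape, layouts_all[0].shape)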