# dataset_and_utils.py - Optimized and Improved Version

import os
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
import PIL
import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoTokenizer, PretrainedConfig


def prepare_image(image: PIL.Image.Image, width: int = 512, height: int = 512) -> torch.Tensor:
    """
    Prepare an image for model input: resize to (width, height) and normalize
    pixel values to [-1, 1]. Returns a float tensor of shape (1, 3, H, W).
    """
    image = image.resize((width, height), resample=Image.BICUBIC, reducing_gap=1)
    arr = np.array(image.convert("RGB"), dtype=np.float32) / 127.5 - 1.0
    # HWC -> CHW, then add a batch dimension.
    return torch.from_numpy(np.transpose(arr, (2, 0, 1))).unsqueeze(0)


def prepare_mask(mask: PIL.Image.Image, width: int = 512, height: int = 512) -> torch.Tensor:
    """
    Prepare a mask for model input: resize to (width, height) and normalize
    values to [0, 1]. Returns a float tensor of shape (1, 1, H, W).
    """
    mask = mask.resize((width, height), resample=Image.BICUBIC, reducing_gap=1)
    arr = np.array(mask.convert("L"), dtype=np.float32) / 255.0
    # Add channel and batch dimensions: HW -> 1x1xHxW.
    return torch.from_numpy(np.expand_dims(arr, 0)).unsqueeze(0)
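
# Example (hypothetical file names, for illustration only): with the default
# size, prepare_image(Image.open("photo.png")) returns a (1, 3, 512, 512)
# tensor with values in [-1, 1], and prepare_mask(Image.open("mask.png"))
# returns a (1, 1, 512, 512) tensor with values in [0, 1].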


class PreprocessedDataset(Dataset):
    """
    Dataset that pre-processes images, masks, and text data for training.
    """

    def __init__(
        self,
        csv_path: str,
        tokenizer_1,
        tokenizer_2,
        vae_encoder,
        text_encoder_1=None,
        text_encoder_2=None,
        do_cache: bool = False,
        size: int = 512,
        text_dropout: float = 0.0,
        scale_vae_latents: bool = True,
        substitute_caption_map: Optional[Dict[str, str]] = None,
    ):
        super().__init__()
        self.data = pd.read_csv(csv_path)
        self.size = size
        self.scale_vae_latents = scale_vae_latents
        self.text_dropout = text_dropout
        self.csv_path = csv_path
        self.tokenizer_1 = tokenizer_1
        self.tokenizer_2 = tokenizer_2
        self.vae_encoder = vae_encoder
        self.do_cache = do_cache

        self.caption = self.data["caption"].str.lower()
        if substitute_caption_map:
            # Literal (non-regex) substitution of trigger tokens in captions.
            for key, value in substitute_caption_map.items():
                self.caption = self.caption.str.replace(key.lower(), value, regex=False)

        self.image_path = self.data["image_path"]
        self.mask_path = self.data["mask_path"] if "mask_path" in self.data.columns else None

        if text_encoder_1:
            self.text_encoder_1 = text_encoder_1
            self.text_encoder_2 = text_encoder_2
            self.return_text_embeddings = True
            raise NotImplementedError("Preprocessing for text encoders is not implemented yet.")
        else:
            self.return_text_embeddings = False

        if self.do_cache:
            self.vae_latents = []
            self.tokens_tuple = []
            self.masks = []
            print("Caching dataset...")
            for idx in range(len(self.data)):
                token, vae_latent, mask = self._process(idx)
                self.tokens_tuple.append(token)
                self.vae_latents.append(vae_latent)
                self.masks.append(mask)
            del self.vae_encoder  # All latents are cached; free up memory.

    def _process(
        self, idx: int
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        """
        Process the image, caption tokens, and mask for a given index.
        """
        # Paths in the CSV are resolved relative to the CSV's directory.
        image_path = os.path.join(os.path.dirname(self.csv_path), self.image_path[idx])
        image = prepare_image(Image.open(image_path).convert("RGB"), self.size, self.size).to(
            dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
        )

        caption = self.caption[idx]
        # Tokenize the caption for both text encoders (77-token context).
        ti1 = self.tokenizer_1(
            caption, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
        ).input_ids
        ti2 = self.tokenizer_2(
            caption, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
        ).input_ids

        vae_latent = self.vae_encoder.encode(image).latent_dist.sample()
        if self.scale_vae_latents:
            vae_latent *= self.vae_encoder.config.scaling_factor

        if self.mask_path is None:
            # No mask supplied: train on the full image.
            mask = torch.ones_like(vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device)
        else:
            mask_path = os.path.join(os.path.dirname(self.csv_path), self.mask_path[idx])
            mask = prepare_mask(Image.open(mask_path), self.size, self.size).to(
                dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
            )
            # Downsample the pixel-space mask to the latent resolution and
            # broadcast it across the latent channels.
            mask = torch.nn.functional.interpolate(
                mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest"
            )
            mask = mask.repeat(1, vae_latent.shape[1], 1, 1)

        assert mask.shape == vae_latent.shape, "Mask and latent dimensions must match."
        return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze()

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        return self.atidx(idx)

    def atidx(self, idx: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        # Serve cached tensors when available; otherwise process on the fly.
        if self.do_cache:
            return self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx]
        return self._process(idx)
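
# Expected CSV layout (illustrative example; "mask_path" is optional, and all
# paths are resolved relative to the directory containing the CSV):
#
#   caption,image_path,mask_path
#   "a photo of TOK",images/0.png,masks/0.png
#   "TOK riding a bike",images/1.png,masks/1.png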


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """
    Dynamically imports a model class based on configuration.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = config.architectures[0]
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection

        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"Unsupported model class: {model_class}")
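
# For example (based on stock SDXL checkpoints): the "text_encoder" subfolder
# typically resolves to CLIPTextModel and "text_encoder_2" to
# CLIPTextModelWithProjection.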


def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
    """
    Loads the tokenizers, scheduler, text encoders, VAE, and UNet from a
    given pretrained path.
    """
    tokenizer_1 = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer", revision=revision, use_fast=False
    )
    tokenizer_2 = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer_2", revision=revision, use_fast=False
    )
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

    text_encoder_cls_one = import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision)
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
    )
    text_encoder_1 = text_encoder_cls_one.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    text_encoder_2 = text_encoder_cls_two.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
    )
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", revision=revision)

    # Freeze the VAE and text encoders; only the UNet stays trainable.
    for model in [vae, text_encoder_1, text_encoder_2]:
        model.requires_grad_(False)
        model.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)

    return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet
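

# Minimal usage sketch (hypothetical, not part of the module's API): the model
# id and the local "data.csv" path below are assumptions for illustration; any
# SDXL-style checkpoint with two tokenizers/text encoders should work the same
# way.
if __name__ == "__main__":
    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models(
        "stabilityai/stable-diffusion-xl-base-1.0",
        revision=None,
        device="cuda" if torch.cuda.is_available() else "cpu",
        weight_dtype=torch.float32,
    )
    dataset = PreprocessedDataset(
        "data.csv",
        tokenizer_1,
        tokenizer_2,
        vae_encoder=vae,
        do_cache=True,  # encode every example once up front
    )
    (tokens_1, tokens_2), latent, mask = dataset[0]
    print(tokens_1.shape, tokens_2.shape, latent.shape, mask.shape)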