"""PyTorch datasets that pair a target image with a conditioning image.

Every dataset returns a dict with at least ``image`` (target tensor),
``condition`` (condition tensor), ``condition_type``, ``description`` (text
prompt, possibly emptied by text dropout) and ``position_delta``.
"""

from PIL import Image, ImageFilter, ImageDraw
import cv2
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms as T
import random


class Subject200KDataset(Dataset):
    """Split side-by-side subject pairs into (target, condition) samples.

    Each item of ``base_dataset`` must provide:

    * ``"image"``: a PIL image holding two ``image_size`` x ``image_size``
      panels, offset by ``padding`` pixels from the top/left edge and from
      each other;
    * ``"description"``: a dict with ``"description_0"`` (left panel) and
      ``"description_1"`` (right panel).

    Every base item yields two samples: even indices use the left panel as the
    target and the right one as the condition, odd indices the reverse.

    Args:
        base_dataset: Indexable source of items as described above.
        condition_size: Side length the condition panel is resized to.
        target_size: Side length the target panel is resized to.
        image_size: Side length of each panel inside the source image.
        padding: Pixel offset around/between the panels.
        condition_type: Tag copied verbatim into every sample.
        drop_text_prob: Probability of replacing the prompt with ``""``.
        drop_image_prob: Probability of replacing the condition with black.
        return_pil_image: If True, also return the uncropped source image
            under ``"pil_image"``.
    """

    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        image_size: int = 512,
        padding: int = 0,
        condition_type: str = "subject",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.image_size = image_size
        self.padding = padding
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image

        self.to_tensor = T.ToTensor()

    def __len__(self):
        # Two samples per base item (either panel can be the target).
        return len(self.base_dataset) * 2

    def __getitem__(self, idx):
        # If target is 0, left image is target, right image is condition
        target = idx % 2
        item = self.base_dataset[idx // 2]

        # Crop the two panels out of the side-by-side image.
        image = item["image"]
        left_img = image.crop(
            (
                self.padding,
                self.padding,
                self.image_size + self.padding,
                self.image_size + self.padding,
            )
        )
        right_img = image.crop(
            (
                self.image_size + self.padding * 2,
                self.padding,
                self.image_size * 2 + self.padding * 2,
                self.image_size + self.padding,
            )
        )

        # Get the target and condition image
        target_image, condition_img = (
            (left_img, right_img) if target == 0 else (right_img, left_img)
        )

        # Resize the image
        condition_img = condition_img.resize(
            (self.condition_size, self.condition_size)
        ).convert("RGB")
        target_image = target_image.resize(
            (self.target_size, self.target_size)
        ).convert("RGB")

        # The prompt describes whichever panel is the target.
        description = item["description"][
            "description_0" if target == 0 else "description_1"
        ]

        # Randomly drop text or image (classifier-free-guidance style dropout).
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(target_image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            # 16 is the downscale factor of the image
            "position_delta": np.array([0, -self.condition_size // 16]),
            **({"pil_image": image} if self.return_pil_image else {}),
        }


class ImageConditionDataset(Dataset):
    """Derive a spatial condition image from each target image on the fly.

    Each item of ``base_dataset`` must provide ``"jpg"`` (a PIL image) and
    ``"json"`` (a dict with a ``"prompt"`` string). The target is the image
    resized to ``target_size`` x ``target_size``; the condition is produced
    according to ``condition_type``:

    * ``"canny"``: Canny edge map.
    * ``"coloring"``: grayscale version.
    * ``"deblurring"``: Gaussian-blurred version (radius 1-10).
    * ``"depth"``: depth map from ``depth_pipe``.
    * ``"depth_pred"``: inverse task, where the RGB image is the condition and
      the depth map becomes the target (prompt prefixed with ``"[depth] "``).
    * ``"fill"``: image with a random rectangle (or its complement) blacked out.
    * ``"sr"``: downscaled copy, with a non-zero ``position_delta``.

    All conditions are returned at ``condition_size`` x ``condition_size``,
    the same size as the black placeholder used for image dropout, so samples
    can be batched together.

    Args:
        base_dataset: Indexable source of items as described above.
        condition_size: Side length of the condition image.
        target_size: Side length of the target image.
        condition_type: One of the types listed above.
        drop_text_prob: Probability of replacing the prompt with ``""``.
        drop_image_prob: Probability of replacing the condition with black.
        return_pil_image: If True, also return ``[image, condition_img]``
            under ``"pil_image"``.
        position_scale: Returned under ``"position_scale"`` when not 1.0;
            presumably consumed downstream to scale condition token positions
            (TODO confirm with the consumer).
    """

    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "canny",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale=1.0,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale

        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    @property
    def depth_pipe(self):
        """Lazily build (once) and return a CPU depth-estimation pipeline."""
        if not hasattr(self, "_depth_pipe"):
            # Imported here so transformers is only required for depth tasks.
            from transformers import pipeline

            self._depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cpu",
            )
        return self._depth_pipe

    def _get_canny_edge(self, img, size=None):
        """Return a 3-channel Canny edge map of the RGB image ``img``.

        The image is first scaled, keeping its aspect ratio, so that its
        longer side equals ``size`` (defaults to ``self.condition_size``).
        """
        if size is None:
            size = self.condition_size
        resize_ratio = size / max(img.size)
        # round() rather than int(): truncation can lose a pixel to float
        # error (e.g. 49 * (1 / 49) == 0.9999999999999999), which would make
        # the edge map smaller than the other conditions in the batch.
        new_size = tuple(max(1, round(dim * resize_ratio)) for dim in img.size)
        img = img.resize(new_size)
        img_np = np.array(img)
        img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(img_gray, 100, 200)
        return Image.fromarray(edges).convert("RGB")

    def __getitem__(self, idx):
        # Fetch the item once; indexing twice may load/decode it twice.
        item = self.base_dataset[idx]
        image = item["jpg"]
        image = image.resize((self.target_size, self.target_size)).convert("RGB")
        description = item["json"]["prompt"]

        # NOTE(review): random.random() is in [0, 1), so enable_scale is always
        # True and the first branch is unreachable. The draw is kept so the
        # RNG sequence (and thus seeded runs) stays unchanged.
        enable_scale = random.random() < 1
        if not enable_scale:
            condition_size = int(self.condition_size * self.position_scale)
            position_scale = 1.0
        else:
            condition_size = self.condition_size
            position_scale = self.position_scale

        # Get the condition image
        position_delta = np.array([0, 0])
        if self.condition_type == "canny":
            condition_img = self._get_canny_edge(image, condition_size)
        elif self.condition_type == "coloring":
            condition_img = (
                image.resize((condition_size, condition_size))
                .convert("L")
                .convert("RGB")
            )
        elif self.condition_type == "deblurring":
            blur_radius = random.randint(1, 10)
            condition_img = (
                image.convert("RGB")
                .filter(ImageFilter.GaussianBlur(blur_radius))
                .resize((condition_size, condition_size))
                .convert("RGB")
            )
        elif self.condition_type == "depth":
            condition_img = self.depth_pipe(image)["depth"].convert("RGB")
            condition_img = condition_img.resize((condition_size, condition_size))
        elif self.condition_type == "depth_pred":
            # Inverse task: predict depth from the full-resolution RGB image,
            # then use the (resized) RGB image as the condition.
            depth_img = self.depth_pipe(image)["depth"].convert("RGB")
            condition_img = image.resize((condition_size, condition_size))
            image = depth_img.resize((self.target_size, self.target_size))
            description = f"[depth] {description}"
        elif self.condition_type == "fill":
            w, h = image.size
            x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
            y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
            mask = Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask)
            draw.rectangle([x1, y1, x2, y2], fill=255)
            # Half the time keep the outside of the rectangle instead.
            if random.random() > 0.5:
                mask = Image.eval(mask, lambda a: 255 - a)
            # Resize so the condition matches condition_size like every other
            # branch and the black drop-image placeholder.
            condition_img = Image.composite(
                image, Image.new("RGB", image.size, (0, 0, 0)), mask
            ).resize((condition_size, condition_size))
        elif self.condition_type == "sr":
            condition_img = image.resize((condition_size, condition_size)).convert(
                "RGB"
            )
            position_delta = np.array([0, -condition_size // 16])

        else:
            raise ValueError(f"Condition type {self.condition_type} not implemented")

        # Randomly drop text or image
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (condition_size, condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
            **({"position_scale": position_scale} if position_scale != 1.0 else {}),
        }


class CartoonDataset(Dataset):
    """Pair a cartoon condition image with a target image and a prompt.

    Each item of ``base_dataset`` must provide PIL images ``"condition"`` and
    ``"target"``. It may provide a ``"description"``. When that key is absent,
    the prompt is built from ``"tags"`` (the first tag must be a key of
    ``_TAG_DESCRIPTIONS``) and ``"target_description"`` (a dict with
    ``"facing_direction"`` and ``"pose"``).

    Args:
        base_dataset: Indexable source of items as described above.
        condition_size: Side length the condition image is resized to.
        target_size: Side length the target image is resized to.
        image_size: Stored but not used by this class.
        padding: Stored but not used by this class.
        condition_type: Tag copied verbatim into every sample.
        drop_text_prob: Probability of replacing the prompt with ``""``.
        drop_image_prob: Probability of replacing the condition with black.
        return_pil_image: If True, also return ``[target_image, condition_img]``
            under ``"pil_image"``.
    """

    # First tag of a sample -> subject phrase used in the fallback prompt.
    _TAG_DESCRIPTIONS = {
        "lion": "lion like animal",
        "bear": "bear like animal",
        "gorilla": "gorilla like animal",
        "dog": "dog like animal",
        "elephant": "elephant like animal",
        "eagle": "eagle like bird",
        "tiger": "tiger like animal",
        "owl": "owl like bird",
        "woman": "woman",
        "parrot": "parrot like bird",
        "mouse": "mouse like animal",
        "man": "man",
        "pigeon": "pigeon like bird",
        "girl": "girl",
        "panda": "panda like animal",
        "crocodile": "crocodile like animal",
        "rabbit": "rabbit like animal",
        "boy": "boy",
        "monkey": "monkey like animal",
        "cat": "cat like animal",
    }

    def __init__(
        self,
        base_dataset,
        condition_size: int = 1024,
        target_size: int = 1024,
        image_size: int = 1024,
        padding: int = 0,
        condition_type: str = "cartoon",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.image_size = image_size
        self.padding = padding
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image

        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    def _default_description(self, data):
        """Build the templated prompt from ``tags`` and ``target_description``."""
        subject = self._TAG_DESCRIPTIONS[data["tags"][0]]
        target_description = data["target_description"]
        return (
            f"Photo of a {subject} cartoon character in a white background. "
            f"Character is facing {target_description['facing_direction']}. "
            f"Character pose is {target_description['pose']}."
        )

    def __getitem__(self, idx):
        data = self.base_dataset[idx]

        # Resize the image
        condition_img = data["condition"].resize(
            (self.condition_size, self.condition_size)
        ).convert("RGB")
        target_image = data["target"].resize(
            (self.target_size, self.target_size)
        ).convert("RGB")

        # Only build the templated prompt when no explicit description exists,
        # so samples carrying their own description don't need tag metadata.
        if "description" in data:
            description = data["description"]
        else:
            description = self._default_description(data)

        # Randomly drop text or image
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob
        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(target_image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            # 16 is the downscale factor of the image
            # NOTE(review): unlike Subject200KDataset this is a fixed -16 rather
            # than -condition_size // 16; confirm that is intended.
            "position_delta": np.array([0, -16]),
            **(
                {"pil_image": [target_image, condition_img]}
                if self.return_pil_image
                else {}
            ),
        }