import copy
import json
import os
import platform
import random
import traceback
from functools import lru_cache
from typing import List, TYPE_CHECKING

import cv2
import numpy as np
import torch
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm
import albumentations as A

from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator


def is_native_windows():
    # used below to disable multiprocessing dataloader workers when running natively on Windows
    return platform.system() == "Windows" and platform.release() != "2"


if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion

image_extensions = ['.jpg', '.jpeg', '.png', '.webp']
video_extensions = ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.wmv', '.m4v', '.flv']


class RescaleTransform:
    """Transform to rescale images to the range [-1, 1]."""

    def __call__(self, image):
        return image * 2 - 1


class NormalizeSDXLTransform:
    """
    Normalizes images from the range [0, 1] to the SDXL per-channel mean and std,
    based on averages over thousands of images.
    Mean: tensor([ 0.0002, -0.1034, -0.1879])
    Standard Deviation: tensor([0.5436, 0.5116, 0.5033])
    """

    def __call__(self, image):
        return transforms.Normalize(
            mean=[0.0002, -0.1034, -0.1879],
            std=[0.5436, 0.5116, 0.5033],
        )(image)


class NormalizeSD15Transform:
    """
    Normalizes images from the range [0, 1] to the SD 1.5 per-channel mean and std,
    based on averages over thousands of images.
    Mean: tensor([-0.1600, -0.2450, -0.3227])
    Standard Deviation: tensor([0.5319, 0.4997, 0.5139])
    """

    def __call__(self, image):
        return transforms.Normalize(
            mean=[-0.1600, -0.2450, -0.3227],
            std=[0.5319, 0.4997, 0.5139],
        )(image)
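

# Illustrative sketch (not called anywhere in the toolkit): how the transforms
# above compose in practice, mirroring the `standardize_images` branch of
# AiToolkitDataset further down. The dummy image is hypothetical.
def _example_standardize_pipeline():
    dummy = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
    pipeline = transforms.Compose([
        transforms.ToTensor(),      # uint8 HWC -> float CHW in [0, 1]
        RescaleTransform(),         # [0, 1] -> [-1, 1]
        NormalizeSDXLTransform(),   # shift/scale toward the measured SDXL channel stats
    ])
    return pipeline(dummy)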


class ImageDataset(Dataset, CaptionMixin):
    def __init__(self, config):
        self.config = config
        self.name = self.get_config('name', 'dataset')
        self.path = self.get_config('path', required=True)
        self.scale = self.get_config('scale', 1)
        self.random_scale = self.get_config('random_scale', False)
        self.include_prompt = self.get_config('include_prompt', False)
        self.default_prompt = self.get_config('default_prompt', '')
        if self.include_prompt:
            self.caption_type = self.get_config('caption_ext', 'txt')
        else:
            self.caption_type = None
        # we always random crop if random scale is enabled
        self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False)
        self.resolution = self.get_config('resolution', 256)
        self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
                          file.lower().endswith(tuple(image_extensions))]

        # this might take a while
        print_acc(" - Preprocessing image dimensions")
        new_file_list = []
        bad_count = 0
        for file in tqdm(self.file_list):
            img = Image.open(file)
            # keep only images that are still at least `resolution` after scaling
            if int(min(img.size) * self.scale) >= self.resolution:
                new_file_list.append(file)
            else:
                bad_count += 1
        self.file_list = new_file_list
        print_acc(f" - Found {len(self.file_list)} images")
        print_acc(f" - Found {bad_count} images that are too small")
        assert len(self.file_list) > 0, f"no images found in {self.path}"

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            RescaleTransform(),
        ])

    def get_config(self, key, default=None, required=False):
        if key in self.config:
            return self.config[key]
        elif required:
            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
        else:
            return default

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        img_path = self.file_list[index]
        try:
            img = exif_transpose(Image.open(img_path)).convert('RGB')
        except Exception as e:
            print_acc(f"Error opening image: {img_path}")
            print_acc(e)
            # fall back to a noise image if the file cannot be opened
            img = Image.fromarray(np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8))

        # downscale the source image first
        img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
        min_img_size = min(img.size)

        if self.random_crop:
            if self.random_scale and min_img_size > self.resolution:
                # resize to a random intermediate size before cropping
                scale_size = random.randint(self.resolution, int(min_img_size))
                scaler = scale_size / min_img_size
                scale_width = int((img.width + 5) * scaler)
                scale_height = int((img.height + 5) * scaler)
                img = img.resize((scale_width, scale_height), Image.BICUBIC)
            img = transforms.RandomCrop(self.resolution)(img)
        else:
            img = transforms.CenterCrop(min_img_size)(img)
            img = img.resize((self.resolution, self.resolution), Image.BICUBIC)

        img = self.transform(img)

        if self.include_prompt:
            prompt = self.get_caption_item(index)
            return img, prompt
        else:
            return img
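

# Illustrative sketch (not called anywhere): constructing an ImageDataset from a
# plain config dict. The path is hypothetical; the keys shown are the ones read
# by get_config above.
def _example_image_dataset():
    config = {
        'path': '/path/to/images',   # hypothetical folder of jpg/png/webp files
        'resolution': 512,
        'include_prompt': True,      # items become (image_tensor, caption) pairs
    }
    dataset = ImageDataset(config)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    return next(iter(loader))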


class AugmentedImageDataset(ImageDataset):
    def __init__(self, config):
        super().__init__(config)
        self.augmentations = self.get_config('augmentations', [])
        self.augmentations = [Augments(**aug) for aug in self.augmentations]
        augmentation_list = []
        for aug in self.augmentations:
            # make sure the method name is a valid albumentations transform
            assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
            # get the transform class
            method = getattr(A, aug.method_name)
            # instantiate it with its params and add it to the list
            augmentation_list.append(method(**aug.params))

        self.aug_transform = A.Compose(augmentation_list)
        self.original_transform = self.transform
        # replace the transform with a no-op so the parent class returns a raw PIL image
        self.transform = transforms.Compose([])

    def __getitem__(self, index):
        # the parent now returns a raw PIL image; convert RGB -> BGR for cv2/albumentations
        pil_image = super().__getitem__(index)
        open_cv_image = np.array(pil_image)
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        # apply augmentations
        augmented = self.aug_transform(image=open_cv_image)["image"]
        # convert back to an RGB PIL image
        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
        augmented = Image.fromarray(augmented)
        # return both the original and the augmented image as 0-1 tensors
        return transforms.ToTensor()(pil_image), transforms.ToTensor()(augmented)
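

# Illustrative sketch: each `augmentations` entry names an albumentations
# transform (`method_name`) plus its constructor kwargs (`params`). The values
# here are hypothetical examples, not toolkit defaults.
def _example_augmented_dataset():
    config = {
        'path': '/path/to/images',  # hypothetical folder
        'resolution': 512,
        'augmentations': [
            {'method_name': 'HorizontalFlip', 'params': {'p': 0.5}},
            {'method_name': 'GaussNoise', 'params': {'p': 0.3}},
        ],
    }
    dataset = AugmentedImageDataset(config)
    original, augmented = dataset[0]  # both are 0-1 float tensors
    return original, augmented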


class PairedImageDataset(Dataset):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.size = self.get_config('size', 512)
        self.path = self.get_config('path', None)
        self.pos_folder = self.get_config('pos_folder', None)
        self.neg_folder = self.get_config('neg_folder', None)
        self.default_prompt = self.get_config('default_prompt', '')
        self.network_weight = self.get_config('network_weight', 1.0)
        self.pos_weight = self.get_config('pos_weight', self.network_weight)
        self.neg_weight = self.get_config('neg_weight', self.network_weight)
        # comparison below is done on the lowercased filename
        supported_exts = ('.jpg', '.jpeg', '.png', '.webp')
        if self.pos_folder is not None and self.neg_folder is not None:
            # find files whose basenames match across the two folders
            self.pos_file_list = [os.path.join(self.pos_folder, file) for file in os.listdir(self.pos_folder) if
                                  file.lower().endswith(supported_exts)]
            self.neg_file_list = [os.path.join(self.neg_folder, file) for file in os.listdir(self.neg_folder) if
                                  file.lower().endswith(supported_exts)]
            matched_files = []
            for pos_file in self.pos_file_list:
                pos_file_no_ext = os.path.splitext(pos_file)[0]
                for neg_file in self.neg_file_list:
                    neg_file_no_ext = os.path.splitext(neg_file)[0]
                    if os.path.basename(pos_file_no_ext) == os.path.basename(neg_file_no_ext):
                        matched_files.append((neg_file, pos_file))
                        break
            # remove duplicates
            matched_files = list(set(matched_files))
            self.file_list = matched_files
            print_acc(f" - Found {len(self.file_list)} matching pairs")
        else:
            self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
                              file.lower().endswith(supported_exts)]
            print_acc(f" - Found {len(self.file_list)} images")

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            RescaleTransform(),
        ])

    def get_all_prompts(self):
        prompts = []
        for index in range(len(self.file_list)):
            prompts.append(self.get_prompt_item(index))
        # remove duplicates
        prompts = list(set(prompts))
        return prompts

    def __len__(self):
        return len(self.file_list)

    def get_config(self, key, default=None, required=False):
        if key in self.config:
            return self.config[key]
        elif required:
            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
        else:
            return default

    def get_prompt_item(self, index):
        img_path_or_tuple = self.file_list[index]
        if isinstance(img_path_or_tuple, tuple):
            # check if either image has a prompt file
            path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
            prompt_path = path_no_ext + '.txt'
            if not os.path.exists(prompt_path):
                path_no_ext = os.path.splitext(img_path_or_tuple[1])[0]
                prompt_path = path_no_ext + '.txt'
        else:
            img_path = img_path_or_tuple
            # see if a prompt file exists
            path_no_ext = os.path.splitext(img_path)[0]
            prompt_path = path_no_ext + '.txt'

        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt = f.read()
            # replace newlines (any OS) with comma separators
            prompt = prompt.replace('\n', ', ')
            prompt = prompt.replace('\r', ', ')
            prompt_split = prompt.split(',')
            # remove empty strings and re-join
            prompt_split = [p.strip() for p in prompt_split if p.strip()]
            prompt = ', '.join(prompt_split)
        else:
            prompt = self.default_prompt
        return prompt

    def __getitem__(self, index):
        img_path_or_tuple = self.file_list[index]
        if isinstance(img_path_or_tuple, tuple):
            # load both images (neg, pos)
            img_path = img_path_or_tuple[0]
            img1 = exif_transpose(Image.open(img_path)).convert('RGB')
            img_path = img_path_or_tuple[1]
            img2 = exif_transpose(Image.open(img_path)).convert('RGB')

            # always bucket based on img2 (the positive image)
            bucket_resolution = get_bucket_for_image_size(
                width=img2.width,
                height=img2.height,
                resolution=self.size,
                # divisibility=self.
            )
            # images share a base dimension but may be trimmed; shrink, then center crop
            if bucket_resolution['width'] > bucket_resolution['height']:
                img1_scale_to_height = bucket_resolution["height"]
                img1_scale_to_width = int(img1.width * (bucket_resolution["height"] / img1.height))
                img2_scale_to_height = bucket_resolution["height"]
                img2_scale_to_width = int(img2.width * (bucket_resolution["height"] / img2.height))
            else:
                img1_scale_to_width = bucket_resolution["width"]
                img1_scale_to_height = int(img1.height * (bucket_resolution["width"] / img1.width))
                img2_scale_to_width = bucket_resolution["width"]
                img2_scale_to_height = int(img2.height * (bucket_resolution["width"] / img2.width))

            img1_crop_height = bucket_resolution["height"]
            img1_crop_width = bucket_resolution["width"]
            img2_crop_height = bucket_resolution["height"]
            img2_crop_width = bucket_resolution["width"]

            # scale, then center crop both images
            img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC)
            img1 = transforms.CenterCrop((img1_crop_height, img1_crop_width))(img1)
            img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC)
            img2 = transforms.CenterCrop((img2_crop_height, img2_crop_width))(img2)

            # combine them side by side (neg on the left, pos on the right)
            img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
            img.paste(img1, (0, 0))
            img.paste(img2, (img1.width, 0))
        else:
            img_path = img_path_or_tuple
            img = exif_transpose(Image.open(img_path)).convert('RGB')
            height = self.size
            # determine the width that keeps the aspect ratio
            width = int(img.size[0] * height / img.size[1])
            img = img.resize((width, height), Image.BICUBIC)

        prompt = self.get_prompt_item(index)
        img = self.transform(img)
        return img, prompt, (self.neg_weight, self.pos_weight)
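

# Illustrative sketch: pairing a positive and a negative folder by basename.
# Folder paths are hypothetical. Each item is a side-by-side neg|pos tensor,
# the prompt, and the (neg_weight, pos_weight) tuple used during training.
def _example_paired_dataset():
    config = {
        'pos_folder': '/path/to/pos',  # hypothetical
        'neg_folder': '/path/to/neg',  # hypothetical
        'size': 512,
        'network_weight': 1.0,
    }
    dataset = PairedImageDataset(config)
    img, prompt, weights = dataset[0]
    return img, prompt, weights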


class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
    def __init__(
            self,
            dataset_config: 'DatasetConfig',
            batch_size=1,
            sd: 'StableDiffusion' = None,
    ):
        self.dataset_config = dataset_config
        # update bucket divisibility to match the model
        self.dataset_config.bucket_tolerance = sd.get_bucket_divisibility()
        self.is_video = dataset_config.num_frames > 1
        super().__init__()
        folder_path = dataset_config.folder_path
        self.dataset_path = dataset_config.dataset_path
        if self.dataset_path is None:
            self.dataset_path = folder_path
        self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
        self.is_caching_latents_to_memory = dataset_config.cache_latents
        self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
        self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
        self.is_generating_controls = len(dataset_config.controls) > 0
        self.epoch_num = 0
        self.sd = sd
        if self.sd is None and self.is_caching_latents:
            raise ValueError("sd is required for caching latents")
        self.caption_type = dataset_config.caption_ext
        self.default_caption = dataset_config.default_caption
        self.random_scale = dataset_config.random_scale
        self.scale = dataset_config.scale
        self.batch_size = batch_size
        # we always random crop if random scale is enabled
        self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
        self.resolution = dataset_config.resolution
        self.caption_dict = None
        self.file_list: List['FileItemDTO'] = []

        # dataset_path may be a folder or a json caption file
        if os.path.isdir(self.dataset_path):
            extensions = image_extensions
            if self.is_video:
                # only look for videos
                extensions = video_extensions
            file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path)
                         for file in files if file.lower().endswith(tuple(extensions))]
        else:
            # assume json, keyed by file path
            with open(self.dataset_path, 'r') as f:
                self.caption_dict = json.load(f)
            file_list = list(self.caption_dict.keys())

        # remove items in the _controls folder
        file_list = [x for x in file_list if not os.path.basename(os.path.dirname(x)) == "_controls"]

        if self.dataset_config.num_repeats > 1:
            # repeat the list
            file_list = file_list * self.dataset_config.num_repeats

        if self.dataset_config.standardize_images:
            if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
                NormalizeMethod = NormalizeSDXLTransform
            else:
                NormalizeMethod = NormalizeSD15Transform
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                RescaleTransform(),
                NormalizeMethod(),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                RescaleTransform(),
            ])

        # this might take a while
        print_acc(f"Dataset: {self.dataset_path}")
        if self.is_video:
            print_acc(" - Preprocessing video dimensions")
        else:
            print_acc(" - Preprocessing image dimensions")

        dataset_folder = self.dataset_path
        if not os.path.isdir(self.dataset_path):
            dataset_folder = os.path.dirname(dataset_folder)
        dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
        dataloader_version = "0.1.2"
        if os.path.exists(dataset_size_file):
            try:
                with open(dataset_size_file, 'r') as f:
                    self.size_database = json.load(f)
                if self.size_database.get("__version__") != dataloader_version:
                    print_acc("Upgrading size database to new version")
                    # old version, delete and recreate
                    self.size_database = {}
            except Exception as e:
                print_acc(f"Error loading size database: {dataset_size_file}")
                print_acc(e)
                self.size_database = {}
        else:
            self.size_database = {}
        self.size_database["__version__"] = dataloader_version

        bad_count = 0
        for file in tqdm(file_list):
            try:
                file_item = FileItemDTO(
                    sd=self.sd,
                    path=file,
                    dataset_config=dataset_config,
                    dataloader_transforms=self.transform,
                    size_database=self.size_database,
                    dataset_root=dataset_folder,
                )
                self.file_list.append(file_item)
            except Exception as e:
                print_acc(traceback.format_exc())
                if self.is_video:
                    print_acc(f"Error processing video: {file}")
                else:
                    print_acc(f"Error processing image: {file}")
                print_acc(e)
                bad_count += 1

        # save the size database
        with open(dataset_size_file, 'w') as f:
            json.dump(self.size_database, f)

        if self.is_video:
            print_acc(f" - Found {len(self.file_list)} videos")
            assert len(self.file_list) > 0, f"no videos found in {self.dataset_path}"
        else:
            print_acc(f" - Found {len(self.file_list)} images")
            assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"

        # handle x axis flips
        if self.dataset_config.flip_x:
            print_acc(" - adding x axis flips")
            current_file_list = [x for x in self.file_list]
            for file_item in current_file_list:
                # create a copy that is flipped on the x axis
                new_file_item = copy.deepcopy(file_item)
                new_file_item.flip_x = True
                self.file_list.append(new_file_item)

        # handle y axis flips
        if self.dataset_config.flip_y:
            print_acc(" - adding y axis flips")
            current_file_list = [x for x in self.file_list]
            for file_item in current_file_list:
                # create a copy that is flipped on the y axis
                new_file_item = copy.deepcopy(file_item)
                new_file_item.flip_y = True
                self.file_list.append(new_file_item)

        if self.dataset_config.flip_x or self.dataset_config.flip_y:
            if self.is_video:
                print_acc(f" - Found {len(self.file_list)} videos after adding flips")
            else:
                print_acc(f" - Found {len(self.file_list)} images after adding flips")

        self.setup_epoch()

    def setup_epoch(self):
        if self.epoch_num == 0:
            # initial setup
            if self.dataset_config.buckets:
                # setup buckets
                self.setup_buckets()
            if self.is_caching_latents:
                self.cache_latents_all_latents()
            if self.is_caching_clip_vision_to_disk:
                self.cache_clip_vision_to_disk()
            if self.is_generating_controls:
                # always do this last
                self.setup_controls()
        else:
            if self.dataset_config.poi is not None:
                # handle cropping to a specific point of interest
                # rebuild buckets every epoch
                self.setup_buckets(quiet=True)
        self.epoch_num += 1

    def __len__(self):
        if self.dataset_config.buckets:
            return len(self.batch_indices)
        return len(self.file_list)

    def _get_single_item(self, index) -> 'FileItemDTO':
        file_item: 'FileItemDTO' = copy.deepcopy(self.file_list[index])
        file_item.load_and_process_image(self.transform)
        file_item.load_caption(self.caption_dict)
        return file_item

    def __getitem__(self, item):
        if self.dataset_config.buckets:
            # when bucketing, the dataset does its own batching for now
            # todo allow a scheduler to dynamically make buckets
            if len(self.batch_indices) - 1 < item:
                # the number of batches can shrink when buckets are rebuilt;
                # if the index is now out of range, pick a random valid one
                item = random.randint(0, len(self.batch_indices) - 1)
            idx_list = self.batch_indices[item]
            return [self._get_single_item(idx) for idx in idx_list]
        else:
            # the DataLoader handles batching
            return self._get_single_item(item)
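

# Illustrative sketch: when buckets are enabled, indexing the dataset yields a
# ready-made batch (a list of FileItemDTO), which is why the bucket dataloader
# below is built with batch_size=None. The config values are hypothetical and
# `sd` must be a loaded model.
def _example_aitoolkit_dataset(sd: 'StableDiffusion'):
    config = DatasetConfig(folder_path='/path/to/images', resolution=512, buckets=True)
    dataset = AiToolkitDataset(config, batch_size=4, sd=sd)
    return dataset[0]  # a list of FileItemDTO, one bucketed batch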


def get_dataloader_from_datasets(
        dataset_options,
        batch_size=1,
        sd: 'StableDiffusion' = None,
) -> DataLoader:
    if dataset_options is None or len(dataset_options) == 0:
        return None

    datasets = []
    has_buckets = False
    is_caching_latents = False

    # preprocess all configs first
    dataset_config_list = []
    for dataset_option in dataset_options:
        if isinstance(dataset_option, DatasetConfig):
            dataset_config_list.append(dataset_option)
        else:
            # preprocess raw config dicts
            split_configs = preprocess_dataset_raw_config([dataset_option])
            for x in split_configs:
                dataset_config_list.append(DatasetConfig(**x))

    for config in dataset_config_list:
        if config.type == 'image':
            dataset = AiToolkitDataset(config, batch_size=batch_size, sd=sd)
            datasets.append(dataset)
            if config.buckets:
                has_buckets = True
            if config.cache_latents or config.cache_latents_to_disk:
                is_caching_latents = True
        else:
            raise ValueError(f"invalid dataset type: {config.type}")

    concatenated_dataset = ConcatDataset(datasets)

    # todo build scheduler that can get buckets from all datasets that match
    # todo and evenly distribute reg images

    def dto_collation(batch: List['FileItemDTO']):
        # wrap the list of file items in a DTO batch
        return DataLoaderBatchDTO(
            file_items=batch
        )

    # worker settings; multiprocessing workers are disabled on native Windows
    dataloader_kwargs = {}
    if is_native_windows():
        dataloader_kwargs['num_workers'] = 0
    else:
        dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers
        dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor

    if has_buckets:
        # make sure they all have buckets
        for dataset in datasets:
            assert dataset.dataset_config.buckets, \
                f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"
        data_loader = DataLoader(
            concatenated_dataset,
            batch_size=None,  # we batch in the datasets for now
            drop_last=False,
            shuffle=True,
            collate_fn=dto_collation,  # use the custom collate function
            **dataloader_kwargs
        )
    else:
        data_loader = DataLoader(
            concatenated_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=dto_collation,
            **dataloader_kwargs
        )
    return data_loader
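

# Illustrative sketch: building a dataloader from a raw options list. The folder
# path is hypothetical; only config keys referenced elsewhere in this module are
# used. Each yielded batch is a DataLoaderBatchDTO.
def _example_get_dataloader(sd: 'StableDiffusion'):
    dataset_options = [{
        'type': 'image',
        'folder_path': '/path/to/images',  # hypothetical
        'resolution': 512,
        'caption_ext': 'txt',
        'buckets': True,
    }]
    loader = get_dataloader_from_datasets(dataset_options, batch_size=4, sd=sd)
    for batch in loader:
        return batch  # a DataLoaderBatchDTO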


def trigger_dataloader_setup_epoch(dataloader: DataLoader):
    # hacky, but needed because of the different types of datasets and dataloaders
    dataloader.len = None
    if isinstance(dataloader.dataset, list):
        for dataset in dataloader.dataset:
            if hasattr(dataset, 'datasets'):
                for sub_dataset in dataset.datasets:
                    if hasattr(sub_dataset, 'setup_epoch'):
                        sub_dataset.setup_epoch()
                        sub_dataset.len = None
            elif hasattr(dataset, 'setup_epoch'):
                dataset.setup_epoch()
                dataset.len = None
    elif hasattr(dataloader.dataset, 'setup_epoch'):
        dataloader.dataset.setup_epoch()
        dataloader.dataset.len = None
    elif hasattr(dataloader.dataset, 'datasets'):
        dataloader.dataset.len = None
        for sub_dataset in dataloader.dataset.datasets:
            if hasattr(sub_dataset, 'setup_epoch'):
                sub_dataset.setup_epoch()
                sub_dataset.len = None


def get_dataloader_datasets(dataloader: DataLoader):
    # hacky, but needed because of the different types of datasets and dataloaders
    if isinstance(dataloader.dataset, list):
        datasets = []
        for dataset in dataloader.dataset:
            if hasattr(dataset, 'datasets'):
                for sub_dataset in dataset.datasets:
                    datasets.append(sub_dataset)
            else:
                datasets.append(dataset)
        return datasets
    elif hasattr(dataloader.dataset, 'datasets'):
        return dataloader.dataset.datasets
    else:
        return [dataloader.dataset]
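

# Illustrative sketch: a hypothetical training loop that calls the helper above
# at each epoch boundary so bucket-based datasets can rebuild their batch lists.
def _example_epoch_loop(dataloader: DataLoader, num_epochs: int = 2):
    for epoch in range(num_epochs):
        for batch in dataloader:
            pass  # the training step would go here
        trigger_dataloader_setup_epoch(dataloader)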