# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import os
import pathlib
from typing import Any, Callable, Optional, Union

import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.datasets import ImageFolder

__all__ = ["load_image", "load_image_from_dir", "DMCrop", "CustomImageFolder", "ImageDataset"]
| def load_image(data_path: str, mode="rgb") -> Image.Image: | |
| img = Image.open(data_path) | |
| if mode == "rgb": | |
| img = img.convert("RGB") | |
| return img | |


def load_image_from_dir(
    dir_path: str,
    suffix: Union[str, tuple[str, ...], list[str]] = (".jpg", ".JPEG", ".png"),
    return_mode="path",
    k: Optional[int] = None,
    shuffle_func: Optional[Callable] = None,
) -> Union[list, tuple[list, list]]:
    """Collect image files under dir_path.

    return_mode selects the output: "path" returns file paths, "image" returns
    loaded PIL images, and anything else returns (paths, images). When both
    shuffle_func and k are given, a shuffled subset of k files is kept.
    """
    suffix = [suffix] if isinstance(suffix, str) else suffix
    file_list = []
    for dirpath, _, fnames in os.walk(dir_path):
        for fname in fnames:
            if pathlib.Path(fname).suffix not in suffix:
                continue
            image_path = os.path.join(dirpath, fname)
            file_list.append(image_path)
    if shuffle_func is not None and k is not None:
        shuffle_file_list = shuffle_func(file_list)
        file_list = shuffle_file_list or file_list
        file_list = file_list[:k]
    file_list = sorted(file_list)

    if return_mode == "path":
        return file_list
    else:
        files = []
        path_list = []
        for file_path in file_list:
            try:
                files.append(load_image(file_path))
                path_list.append(file_path)
            except Exception:
                print(f"Failed to load {file_path}")
        if return_mode == "image":
            return files
        else:
            return path_list, files
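

# Minimal usage sketch for load_image_from_dir (illustrative only). The
# directory "data/images" and the random.sample shuffle function below are
# placeholders/assumptions, not part of this module's API.
def _example_load_image_from_dir() -> None:
    import random

    # Deterministic, sorted listing of all matching paths.
    paths = load_image_from_dir("data/images", suffix=(".png", ".jpg"), return_mode="path")
    # Shuffled subset: keep 8 paths chosen by a caller-supplied shuffle function.
    subset = load_image_from_dir(
        "data/images",
        return_mode="path",
        k=8,
        shuffle_func=lambda files: random.sample(files, len(files)),
    )
    print(len(paths), len(subset))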


class DMCrop:
    """Center crop used in diffusion models (ADM-style resize-then-crop)."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __call__(self, pil_image: Image.Image) -> Image.Image:
        """
        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        image_size = self.size
        if pil_image.size == (image_size, image_size):
            return pil_image
        # Repeatedly halve the image with a box filter while it is at least
        # twice the target size, then finish with a single bicubic resize.
        while min(*pil_image.size) >= 2 * image_size:
            pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
        scale = image_size / min(*pil_image.size)
        pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
        # Crop the central image_size x image_size window.
        arr = np.array(pil_image)
        crop_y = (arr.shape[0] - image_size) // 2
        crop_x = (arr.shape[1] - image_size) // 2
        return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
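

# Minimal sketch of DMCrop inside a torchvision transform pipeline
# (illustrative only); "example.png" is a placeholder path.
def _example_dmcrop() -> None:
    from torchvision import transforms

    preprocess = transforms.Compose([
        DMCrop(256),            # ADM-style resize + center crop to 256x256
        transforms.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
    ])
    image = load_image("example.png")
    tensor = preprocess(image)
    print(tensor.shape)  # torch.Size([3, 256, 256]) for an RGB input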


class CustomImageFolder(ImageFolder):
    """ImageFolder variant that can return a dict including the sample's index and path."""

    def __init__(self, root: str, transform: Optional[Callable] = None, return_dict: bool = False):
        root = os.path.expanduser(root)
        self.return_dict = return_dict
        super().__init__(root, transform)

    def __getitem__(self, index: int) -> Union[dict[str, Any], tuple[Any, Any]]:
        path, target = self.samples[index]
        image = load_image(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.return_dict:
            return {
                "index": index,
                "image_path": path,
                "image": image,
                "label": target,
            }
        else:
            return image, target
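

# Minimal usage sketch for CustomImageFolder (illustrative only). Assumes an
# ImageNet-style layout under "data/train" (one subfolder per class); the path
# is a placeholder.
def _example_custom_image_folder() -> None:
    from torch.utils.data import DataLoader
    from torchvision import transforms

    dataset = CustomImageFolder(
        "data/train",
        transform=transforms.Compose([DMCrop(256), transforms.ToTensor()]),
        return_dict=True,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    # Default collation stacks "image" into a [4, 3, 256, 256] tensor and keeps
    # "image_path" as a list of strings.
    print(batch["image"].shape, batch["label"])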


class ImageDataset(Dataset):
    """Flat image dataset built from one or more directories or split files."""

    def __init__(
        self,
        data_dirs: Union[str, list[str]],
        splits: Optional[Union[str, list[Optional[str]]]] = None,
        transform: Optional[Callable] = None,
        suffix=(".jpg", ".JPEG", ".png"),
        pil=True,
        return_dict=True,
    ) -> None:
        super().__init__()
        self.data_dirs = [data_dirs] if isinstance(data_dirs, str) else data_dirs
        if isinstance(splits, list):
            assert len(splits) == len(self.data_dirs)
            self.splits = splits
        elif isinstance(splits, str):
            assert len(self.data_dirs) == 1
            self.splits = [splits]
        else:
            self.splits = [None for _ in range(len(self.data_dirs))]
        self.transform = transform
        self.pil = pil
        self.return_dict = return_dict

        # Collect all image paths, either by walking each directory or by
        # reading relative paths from the corresponding split file.
        self.samples = []
        for data_dir, split in zip(self.data_dirs, self.splits):
            if split is None:
                samples = load_image_from_dir(data_dir, suffix, return_mode="path")
            else:
                samples = []
                with open(split) as fin:
                    for line in fin:
                        relative_path = line.strip()
                        if not relative_path:
                            continue
                        full_path = os.path.join(data_dir, relative_path)
                        samples.append(full_path)
            self.samples += samples

    def __len__(self) -> int:
        return len(self.samples)
    def __getitem__(self, index: int, skip_image=False) -> dict[str, Any]:
        image_path = self.samples[index]
        if skip_image:
            image = None
        else:
            try:
                image = load_image(image_path)
                if not self.pil:
                    # Return a numpy array instead of a PIL image when pil=False.
                    image = np.array(image)
            except Exception as e:
                raise OSError(f"Failed to load {image_path}") from e
            if self.transform is not None:
                image = self.transform(image)
        if self.return_dict:
            return {
                "index": index,
                "image_path": image_path,
                "image_name": os.path.basename(image_path),
                "data": image,
            }
        else:
            return image
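

# Minimal usage sketch for ImageDataset (illustrative only). "data/images" and
# "splits/val.txt" are placeholder paths; the split file is expected to contain
# one image path per line, relative to the data directory.
def _example_image_dataset() -> None:
    from torchvision import transforms

    dataset = ImageDataset(
        data_dirs="data/images",
        splits="splits/val.txt",
        transform=transforms.Compose([DMCrop(512), transforms.ToTensor()]),
        return_dict=True,
    )
    sample = dataset[0]
    print(sample["image_name"], sample["data"].shape)  # e.g. torch.Size([3, 512, 512])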