# Spaces: Running
import pathlib
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import pandas as pd
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive, download_url
class StanfordCarsClass(VisionDataset):
    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset

    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
    split into 8,144 training images and 8,041 testing images, where each class
    has been split roughly in a 50-50 split.

    .. note::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        train (bool, optional): If True (default), use the training split,
            otherwise use the test split.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True (default), downloads the dataset from the
            internet and puts it under the root directory. If the dataset is already
            downloaded, it is not downloaded again.
        root (str or pathlib.Path, optional): Keyword-only. Root directory of the
            dataset. When omitted, the class attribute ``root`` is used.

    Attributes:
        samples (list): ``(image_path, class_index)`` pairs; class indices are zero-based.
        targets (numpy.ndarray): The class index of every sample, in sample order.
        classes (list): Class names read from ``devkit/cars_meta.mat``.
        class_to_idx (dict): Maps each class name to its index in ``classes``.

    Raises:
        RuntimeError: If scipy is not installed, or if the dataset files are not
            present after the optional download step.
    """

    # Default dataset location, used when ``root`` is not passed to ``__init__``.
    root = pathlib.Path.home() / "tmp" / "Datasets" / "StanfordCars"

    def __init__(
        self,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
        *,
        root: Optional[Union[str, pathlib.Path]] = None,
    ) -> None:
        # scipy is imported lazily so that merely importing this module does not require it.
        try:
            import scipy.io as sio
        except ImportError as err:
            raise RuntimeError(
                "Scipy is not found. This dataset needs to have scipy installed: pip install scipy"
            ) from err

        # ``self.root`` still resolves to the class-level default at this point;
        # the base class stores the chosen value on the instance.
        dataset_root = self.root if root is None else root
        super().__init__(dataset_root, transform=transform, target_transform=target_transform)

        self.train = train
        self._base_folder = pathlib.Path(self.root) / "stanford_cars"
        devkit = self._base_folder / "devkit"

        if train:
            self._annotations_mat_path = devkit / "cars_train_annos.mat"
            self._images_base_path = self._base_folder / "cars_train"
        else:
            # The labelled test annotations are downloaded next to the devkit, not inside it.
            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
            self._images_base_path = self._base_folder / "cars_test"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        annotations = sio.loadmat(str(self._annotations_mat_path), squeeze_me=True)["annotations"]
        self.samples = [
            (
                str(self._images_base_path / annotation["fname"]),
                # Original target mapping starts from 1, hence -1. The value is
                # converted to a plain int first so the label type does not depend
                # on whichever numeric dtype the .mat file stores.
                int(annotation["class"]) - 1,
            )
            for annotation in annotations
        ]
        self.targets = np.array([target for _, target in self.samples])

        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Returns pil_image and class_id for given index"""
        image_path, target = self.samples[idx]

        # Close the underlying file as soon as the pixel data has been converted.
        with Image.open(image_path) as raw_image:
            pil_image = raw_image.convert("RGB")

        if self.transform is not None:
            pil_image = self.transform(pil_image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pil_image, target

    def download(self) -> None:
        """Download and extract the devkit and the selected split, unless already present."""
        if self._check_exists():
            return

        download_and_extract_archive(
            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
            download_root=str(self._base_folder),
            md5="c3b158d763b6e2245038c8ad08e45376",
        )
        if self.train:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
                download_root=str(self._base_folder),
                md5="065e5b463ae28d29e77c1b4b166cfe61",
            )
        else:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
                download_root=str(self._base_folder),
                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
            )
            # The test labels are a separate file, fetched alongside the test images.
            download_url(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
                root=str(self._base_folder),
                md5="b0a2b23655a3edd16d84508592a98d10",
            )

    def _check_exists(self) -> bool:
        """Return True when the devkit folder, the annotation file and the image folder all exist."""
        if not (self._base_folder / "devkit").is_dir():
            return False
        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()