|
import pathlib |
|
from typing import Callable, Optional, Any, Tuple |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from PIL import Image |
|
from torchvision.datasets import VisionDataset |
|
from torchvision.datasets.utils import download_and_extract_archive, download_url |
|
|
|
|
|
class StanfordCarsClass(VisionDataset):
    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset

    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
    split into 8,144 training images and 8,041 testing images, where each class
    has been split roughly in a 50-50 split.

    The storage location is not a constructor argument: all files live in a
    ``stanford_cars`` folder below the class attribute ``root``.

    .. note::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        train (bool, optional): If True (default), use the training split,
            otherwise use the test split.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True (default), downloads the dataset from the
            internet and puts it below ``root``. If dataset is already downloaded,
            it is not downloaded again.

    Raises:
        RuntimeError: If scipy is not installed, or if the dataset files cannot be
            found after the (optional) download step.
    """

    # Directory handed to ``VisionDataset`` as the dataset root.
    root = pathlib.Path.home() / "tmp" / "Datasets" / "StanfordCars"

    def __init__(
        self,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
    ) -> None:
        # scipy is imported here, not at module level, so it is only required
        # once a dataset instance is actually created.
        try:
            import scipy.io as sio
        except ImportError as err:
            raise RuntimeError(
                "Scipy is not found. This dataset needs to have scipy installed: pip install scipy"
            ) from err

        super().__init__(self.root, transform=transform, target_transform=target_transform)

        self.train = train
        self._base_folder = pathlib.Path(self.root) / "stanford_cars"
        devkit = self._base_folder / "devkit"

        if train:
            self._annotations_mat_path = devkit / "cars_train_annos.mat"
            self._images_base_path = self._base_folder / "cars_train"
        else:
            # The labelled test annotations are a separate download that is
            # stored next to the devkit folder, not inside it.
            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
            self._images_base_path = self._base_folder / "cars_test"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Paths are passed to scipy as plain strings, consistently for both files.
        annotations = sio.loadmat(str(self._annotations_mat_path), squeeze_me=True)["annotations"]

        # (image path, class index) pairs. The stored class id is shifted by
        # one -- presumably the annotation files use 1-based class ids.
        self.samples = [
            (
                str(self._images_base_path / annotation["fname"]),
                annotation["class"] - 1,
            )
            for annotation in annotations
        ]
        self.targets = np.array([target for _, target in self.samples])
        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Returns pil_image and class_id for given index"""
        image_path, target = self.samples[idx]

        # The context manager releases the underlying file even if decoding
        # fails; ``convert`` returns an independent in-memory RGB image.
        with Image.open(image_path) as raw_image:
            pil_image = raw_image.convert("RGB")

        if self.transform is not None:
            pil_image = self.transform(pil_image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pil_image, target

    def download(self) -> None:
        """Download and extract the files needed for the selected split.

        Does nothing when the files for the selected split are already in
        place. Otherwise fetches the devkit archive plus the image archive of
        the selected split; for the test split the labelled annotation file
        is fetched as well.
        """
        if self._check_exists():
            return

        download_root = str(self._base_folder)

        download_and_extract_archive(
            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
            download_root=download_root,
            md5="c3b158d763b6e2245038c8ad08e45376",
        )
        if self.train:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
                download_root=download_root,
                md5="065e5b463ae28d29e77c1b4b166cfe61",
            )
        else:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
                download_root=download_root,
                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
            )
            download_url(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
                root=download_root,
                md5="b0a2b23655a3edd16d84508592a98d10",
            )

    def _check_exists(self) -> bool:
        """Return True if the devkit folder, annotation file and image folder all exist."""
        if not (self._base_folder / "devkit").is_dir():
            return False

        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
|
|