Spaces:
Sleeping
Sleeping
File size: 4,963 Bytes
9b896f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import pathlib
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import pandas as pd
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive, download_url
class StanfordCarsClass(VisionDataset):
    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset.

    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
    split into 8,144 training images and 8,041 testing images, where each class
    has been split roughly in a 50-50 split.

    .. note::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        train (bool, optional): If True (default), use the training split; otherwise
            use the test split.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True (default), downloads the dataset from the
            internet and puts it in the root directory. If the dataset is already
            downloaded, it is not downloaded again.
        root (str or pathlib.Path, optional): Root directory of the dataset. When
            omitted, the class attribute ``root`` (``~/tmp/Datasets/StanfordCars``)
            is used.

    Attributes:
        samples: List of ``(image_path, class_index)`` pairs; class indices are 0-based.
        targets: NumPy array of the class indices, in the same order as ``samples``.
        classes: Class names read from ``devkit/cars_meta.mat``.
        class_to_idx: Mapping from class name to its index in ``classes``.

    Raises:
        RuntimeError: If scipy is not installed, or if the split's files are missing
            after the optional download step.
    """

    # Default dataset location; an explicit ``root`` argument takes precedence.
    root = pathlib.Path.home() / "tmp" / "Datasets" / "StanfordCars"

    def __init__(
        self,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
        root: Optional[Union[str, pathlib.Path]] = None,
    ) -> None:
        try:
            import scipy.io as sio
        except ImportError as err:
            raise RuntimeError(
                "Scipy is not found. This dataset needs to have scipy installed: pip install scipy"
            ) from err

        # Before super().__init__ runs, ``self.root`` still resolves to the class attribute.
        super().__init__(
            self.root if root is None else root,
            transform=transform,
            target_transform=target_transform,
        )
        self.train = train
        self._base_folder = pathlib.Path(self.root) / "stanford_cars"
        devkit = self._base_folder / "devkit"

        if train:
            self._annotations_mat_path = devkit / "cars_train_annos.mat"
            self._images_base_path = self._base_folder / "cars_train"
        else:
            # The labelled test annotations are a separate file stored beside devkit/.
            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
            self._images_base_path = self._base_folder / "cars_test"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        annotations = sio.loadmat(str(self._annotations_mat_path), squeeze_me=True)["annotations"]
        self.samples = [
            (
                str(self._images_base_path / annotation["fname"]),
                # Original target mapping starts from 1, hence -1. int() turns the numpy
                # scalar read from the .mat file into a plain Python int.
                int(annotation["class"]) - 1,
            )
            for annotation in annotations
        ]
        self.targets = np.array([target for _, target in self.samples])
        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the sample at ``idx``.

        The image is loaded as an RGB PIL image and passed through ``transform``
        when one is set. The 0-based class index is passed through
        ``target_transform`` when one is set.
        """
        image_path, target = self.samples[idx]
        # The context manager closes the file handle deterministically. convert()
        # returns a new, fully loaded image that no longer depends on the file.
        with Image.open(image_path) as img:
            pil_image = img.convert("RGB")
        if self.transform is not None:
            pil_image = self.transform(pil_image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pil_image, target

    def download(self) -> None:
        """Download and extract the devkit and the selected split's images.

        For the test split, the labelled test annotation file is also downloaded.
        Returns immediately if ``_check_exists`` reports that the split is present.
        """
        if self._check_exists():
            return
        # NOTE(review): these ai.stanford.edu URLs may no longer be reachable — verify.
        download_and_extract_archive(
            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
            download_root=str(self._base_folder),
            md5="c3b158d763b6e2245038c8ad08e45376",
        )
        if self.train:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
                download_root=str(self._base_folder),
                md5="065e5b463ae28d29e77c1b4b166cfe61",
            )
        else:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
                download_root=str(self._base_folder),
                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
            )
            download_url(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
                root=str(self._base_folder),
                md5="b0a2b23655a3edd16d84508592a98d10",
            )

    def _check_exists(self) -> bool:
        """Return True if devkit/, the split's annotation file and its image folder all exist."""
        if not (self._base_folder / "devkit").is_dir():
            return False
        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
|