# Motorbike image-classification datasets (PyTorch + Albumentations).
# Images live in a flat directory; filenames prefixed "t_" are the positive
# class and "n_" the negative class.
import copy
import glob
import os
import random

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from pandas.core.common import flatten
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
class MotorbikeDataset(torch.utils.data.Dataset):
    """Binary-classification dataset over a flat directory of images.

    Filenames starting with ``t_`` are the positive class (label 1);
    all other files (the ``n_`` images) are labelled 0.
    """

    def __init__(self, image_paths, transform=None):
        # `image_paths` is the directory holding the images; every entry
        # returned by os.listdir is treated as an image file.
        self.root = image_paths
        self.image_paths = os.listdir(image_paths)
        # Optional Albumentations transform: called as transform(image=...)
        # and expected to return a dict with an "image" key.
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_filepath = self.image_paths[idx]
        image = cv2.imread(os.path.join(self.root, image_filepath))
        # OpenCV decodes to BGR; convert to the RGB ordering models expect.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # BUGFIX: the old check `'t' in image_filepath` fired on ANY 't' in
        # the name (e.g. a ".tif"/".gif" extension, or a stray 't' inside an
        # "n_" filename), mislabelling negatives as positives.  Use the same
        # `t_` prefix rule as Dataset_from_list so both datasets agree.
        label = int(image_filepath.startswith('t_'))
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label
class MotorbikeDataset_CV(torch.utils.data.Dataset):
    """Stratified train/validation split over a directory of `n_*`/`t_*` images.

    The split is computed once at construction; call :meth:`get_split` to get
    the two concrete datasets wrapping the filename lists.
    """

    def __init__(self, root, train_transforms, val_transforms, trainval_ratio=0.8) -> None:
        self.root = root
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        # Fraction of each class assigned to the train split.
        self.trainval_ratio = trainval_ratio
        self.train_split, self.val_split = self.gen_split()

    def __len__(self):
        # BUGFIX: this used to return len(self.root) — the length of the
        # *path string*, not the number of images.  Report the total number
        # of images covered by the split instead.
        return len(self.train_split) + len(self.val_split)

    def gen_split(self):
        """Return (train_filenames, val_filenames), stratified per class."""
        img_list = os.listdir(self.root)
        n_list = [img for img in img_list if img.startswith('n_')]
        t_list = [img for img in img_list if img.startswith('t_')]
        # BUGFIX: random.choices samples WITH replacement, so the train list
        # contained duplicates (effective train size below the requested
        # ratio) while every never-drawn file leaked into val.  random.sample
        # draws k distinct filenames.
        n_train = random.sample(n_list, k=int(len(n_list) * self.trainval_ratio))
        t_train = random.sample(t_list, k=int(len(t_list) * self.trainval_ratio))
        # Validation is everything not chosen for train, per class.
        n_val = [img for img in n_list if img not in n_train]
        t_val = [img for img in t_list if img not in t_train]
        return n_train + t_train, n_val + t_val

    def get_split(self):
        """Build (train_dataset, val_dataset) from the precomputed filename splits."""
        train_dataset = Dataset_from_list(self.root, self.train_split, self.train_transforms)
        val_dataset = Dataset_from_list(self.root, self.val_split, self.val_transforms)
        return train_dataset, val_dataset
class Dataset_from_list(torch.utils.data.Dataset):
    """Dataset over an explicit list of image filenames located under `root`.

    A filename beginning with ``t_`` yields label 1; anything else yields 0.
    """

    def __init__(self, root, img_list, transform) -> None:
        self.root = root
        self.img_list = img_list
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        fname = self.img_list[idx]
        # OpenCV decodes to BGR; convert to RGB before any transform runs.
        bgr = cv2.imread(os.path.join(self.root, fname))
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        target = 1 if fname.startswith('t_') else 0
        if self.transform is not None:
            img = self.transform(image=img)["image"]
        return img, target