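# NOTE: the following lines are residue from the page this file was copied
# from (author avatar caption, commit message, commit hash); they are kept as
# comments so the module remains valid Python.
# MNCJihun's picture
# init
# 25322fb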
import os
import matplotlib.pyplot as plt
from pandas.core.common import flatten
import copy
import numpy as np
import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import glob
from tqdm import tqdm
import random
class MotorbikeDataset(torch.utils.data.Dataset):
    """Binary image dataset over every file in a single directory.

    The label is derived from the file name: names starting with ``'t_'``
    get label 1, all others get label 0 (the same convention used by
    ``Dataset_from_list`` and ``MotorbikeDataset_CV``).

    Args:
        image_paths: Path to the directory holding the images (despite the
            plural name, a single directory path).
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a dict with an ``"image"`` key (presumably an
            albumentations pipeline), or ``None`` to return the raw array.
    """

    def __init__(self, image_paths, transform=None):
        self.root = image_paths
        # Sort so the index -> file mapping does not depend on the arbitrary
        # order in which os.listdir returns directory entries.
        self.image_paths = sorted(os.listdir(image_paths))
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _label_from_filename(filename):
        """Return 1 for ``'t_'``-prefixed file names, else 0."""
        # Use a prefix match: a substring test such as `'t' in filename`
        # would also fire on negatives like "n_motorbike.jpg".
        return int(filename.startswith('t_'))

    def __getitem__(self, idx):
        """Load image ``idx`` as an RGB array and return ``(image, label)``.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        image_filepath = self.image_paths[idx]
        full_path = os.path.join(self.root, image_filepath)
        image = cv2.imread(full_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"Could not read image: {full_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        label = self._label_from_filename(image_filepath)
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label
class MotorbikeDataset_CV(torch.utils.data.Dataset):
    """Per-class random train/validation split of one image directory.

    Only files whose names start with ``'n_'`` (label 0) or ``'t_'``
    (label 1) are considered. Each class is split independently, so both
    subsets keep roughly the directory's class ratio. Use ``get_split`` to
    obtain the two ``Dataset_from_list`` instances.

    Args:
        root: Directory containing the images.
        train_transforms: Transform passed to the training dataset.
        val_transforms: Transform passed to the validation dataset.
        trainval_ratio: Fraction of each class assigned to training
            (``int(len(class) * ratio)`` files); must lie in ``[0, 1]``.
        seed: Optional seed for a private ``random.Random`` so the split is
            reproducible. ``None`` (default) uses the global ``random``
            module state.

    Raises:
        ValueError: if ``trainval_ratio`` is outside ``[0, 1]``.
    """

    def __init__(self, root, train_transforms, val_transforms, trainval_ratio=0.8, seed=None) -> None:
        if not 0.0 <= trainval_ratio <= 1.0:
            raise ValueError(f"trainval_ratio must be in [0, 1], got {trainval_ratio}")
        self.root = root
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.trainval_ratio = trainval_ratio
        self.seed = seed
        self.train_split, self.val_split = self.gen_split()

    def __len__(self):
        # Total number of labelled images across both splits (not the length
        # of the root path string).
        return len(self.train_split) + len(self.val_split)

    def gen_split(self):
        """Randomly partition the ``n_``/``t_`` images into train and val lists.

        Returns:
            ``(train_split, val_split)``: two disjoint lists of file names
            whose union is every ``n_``/``t_`` file in ``root``.
        """
        rng = random if self.seed is None else random.Random(self.seed)
        # Sort first: os.listdir order is arbitrary, which would make a
        # seeded split differ across machines/filesystems.
        img_list = sorted(os.listdir(self.root))
        n_list = [img for img in img_list if img.startswith('n_')]
        t_list = [img for img in img_list if img.startswith('t_')]
        n_train, n_val = self._split_class(n_list, rng)
        t_train, t_val = self._split_class(t_list, rng)
        return n_train + t_train, n_val + t_val

    def _split_class(self, items, rng):
        """Split one class's file names into ``(train, val)`` without replacement."""
        k = int(len(items) * self.trainval_ratio)
        # sample() draws without replacement; random.choices() would draw
        # with replacement, duplicating files and shrinking the real train set.
        train = rng.sample(items, k)
        chosen = set(train)  # O(1) membership for building the complement
        val = [img for img in items if img not in chosen]
        return train, val

    def get_split(self):
        """Return ``(train_dataset, val_dataset)`` built from the stored splits."""
        train_dataset = Dataset_from_list(self.root, self.train_split, self.train_transforms)
        val_dataset = Dataset_from_list(self.root, self.val_split, self.val_transforms)
        return train_dataset, val_dataset
class Dataset_from_list(torch.utils.data.Dataset):
    """Dataset over an explicit list of image file names inside ``root``.

    Labels follow the file-name convention: ``'t_'`` prefix -> 1, else 0.

    Args:
        root: Directory containing the images.
        img_list: File names (relative to ``root``) to expose, in order.
        transform: Callable invoked as ``transform(image=...)`` that returns a
            dict with an ``"image"`` key, or ``None`` for no transform.
    """

    def __init__(self, root, img_list, transform) -> None:
        self.root = root
        self.img_list = img_list
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def _label_from_filename(filename):
        """Return 1 for ``'t_'``-prefixed file names, else 0."""
        return int(filename.startswith('t_'))

    def __getitem__(self, idx):
        """Load image ``idx`` as an RGB array and return ``(image, label)``.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        filename = self.img_list[idx]
        path = os.path.join(self.root, filename)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None on failure instead of raising; surface a
            # clear error rather than an opaque one from cvtColor.
            raise OSError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        label = self._label_from_filename(filename)
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label