File size: 3,138 Bytes
25322fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import matplotlib.pyplot as plt
from pandas.core.common import flatten
import copy
import numpy as np
import random

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

import glob
from tqdm import tqdm
import random

class MotorbikeDataset(torch.utils.data.Dataset):
    """Binary image-classification dataset over a flat directory.

    Label convention (consistent with MotorbikeDataset_CV / Dataset_from_list):
    filenames starting with 't_' are the positive class (1), everything
    else is negative (0).
    """

    def __init__(self, image_paths, transform=None):
        # `image_paths` is the directory containing the images; the actual
        # filenames are discovered here.
        self.root = image_paths
        self.image_paths = os.listdir(image_paths)
        self.transform = transform

    def __len__(self):
        # Number of files found in the directory.
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_filepath = self.image_paths[idx]
        full_path = os.path.join(self.root, image_filepath)

        image = cv2.imread(full_path)
        if image is None:
            # cv2.imread returns None (no exception) on missing/corrupt
            # files; fail loudly with a clear message instead of letting
            # cvtColor raise an opaque OpenCV error.
            raise FileNotFoundError(f"Could not read image: {full_path}")
        # OpenCV loads BGR; convert to the RGB ordering models expect.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # BUGFIX: the previous check `'t' in image_filepath` mislabeled any
        # filename merely containing the letter 't'. Use the 't_' prefix,
        # matching the convention used by the other dataset classes in
        # this file.
        label = int(image_filepath.startswith('t_'))

        if self.transform is not None:
            # Albumentations-style transform: dict in, dict out.
            image = self.transform(image=image)["image"]

        return image, label
    

class MotorbikeDataset_CV(torch.utils.data.Dataset):
    """Stratified train/val splitter over a directory of images whose
    filenames are prefixed 'n_' (negative) or 't_' (positive).

    The split is computed once at construction; get_split() materializes
    the two Dataset objects with their respective transforms.
    """

    def __init__(self, root, train_transforms, val_transforms, trainval_ratio=0.8) -> None:
        self.root = root
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        # Fraction of each class assigned to the training split.
        self.trainval_ratio = trainval_ratio
        self.train_split, self.val_split = self.gen_split()

    def __len__(self):
        # Total number of images covered by the split.
        # BUGFIX: this previously returned len(self.root) -- the length of
        # the directory *path string* -- which was meaningless.
        return len(self.train_split) + len(self.val_split)

    def gen_split(self):
        """Return (train_filenames, val_filenames), stratified per class.

        BUGFIX: random.choices samples WITH replacement, so the train list
        could contain duplicate filenames and the effective train/val
        ratio was not the requested one. random.sample draws without
        replacement, which is what a dataset split requires.
        """
        img_list = os.listdir(self.root)
        n_list = [img for img in img_list if img.startswith('n_')]
        t_list = [img for img in img_list if img.startswith('t_')]

        n_train = random.sample(n_list, k=int(len(n_list) * self.trainval_ratio))
        t_train = random.sample(t_list, k=int(len(t_list) * self.trainval_ratio))

        # Set membership keeps the complement computation O(n) rather
        # than O(n^2) list scans.
        train_set = set(n_train) | set(t_train)
        val_split = [img for img in n_list + t_list if img not in train_set]

        return n_train + t_train, val_split

    def get_split(self):
        """Build and return (train_dataset, val_dataset)."""
        train_dataset = Dataset_from_list(self.root, self.train_split, self.train_transforms)
        val_dataset = Dataset_from_list(self.root, self.val_split, self.val_transforms)
        return train_dataset, val_dataset
    
class Dataset_from_list(torch.utils.data.Dataset):
    """Dataset over an explicit list of image filenames located in `root`.

    The label is derived from the filename: 1 when it starts with 't_',
    0 otherwise.
    """

    def __init__(self, root, img_list, transform) -> None:
        self.root = root
        self.img_list = img_list
        self.transform = transform

    def __len__(self):
        # One sample per listed filename.
        return len(self.img_list)

    def __getitem__(self, idx):
        filename = self.img_list[idx]
        full_path = os.path.join(self.root, filename)

        # OpenCV loads images in BGR order; convert to RGB for the model.
        bgr = cv2.imread(full_path)
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        label = 1 if filename.startswith('t_') else 0

        if self.transform is not None:
            # Albumentations-style transform: dict in, dict out.
            augmented = self.transform(image=image)
            image = augmented["image"]

        return image, label