File size: 5,187 Bytes
eb6241e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dc8dc5
eb6241e
 
 
 
 
2dc8dc5
eb6241e
 
 
 
 
 
 
83b587f
eb6241e
2dc8dc5
eb6241e
 
2dc8dc5
eb6241e
 
 
 
 
 
2dc8dc5
eb6241e
 
2dc8dc5
eb6241e
 
2dc8dc5
eb6241e
 
 
 
 
 
 
 
 
 
 
 
fa80d4b
 
 
 
 
 
 
 
 
 
 
 
 
a81ac51
eb6241e
7ac0794
37b74b4
fa80d4b
 
 
 
a81ac51
d1c3f22
1310ac0
eb6241e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dc8dc5
eb6241e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83b587f
eb6241e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image
import torch.optim as optim
import os
import pandas as pd
from sklearn.model_selection import train_test_split


class CustomDataset(Dataset):
    """Image dataset backed by a DataFrame of ``(image path, label)`` rows.

    Column 0 of the DataFrame is read as the image path and column 1 as the
    label; columns are addressed by position, not by name.

    Args:
        dataframe: Table with the image path in the first column and the
            label in the second column.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            A ``(image, label)`` tuple, where ``image`` is the RGB image
            (transformed when a transform was supplied) and ``label`` is the
            value stored in the second column.
        """
        image_path = self.dataframe.iloc[idx, 0]

        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to load and returns an independent copy, so the file
        # can be closed deterministically instead of waiting for the GC.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self.dataframe.iloc[idx, 1]
        return image, label

def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle ``dataframe`` and split it into training and validation parts.

    Args:
        dataframe: Table of samples to split.
        test_size: Passed to ``train_test_split`` as the validation share.
        random_state: Seed used for both the shuffle and the split.

    Returns:
        A ``(train_df, val_df)`` tuple.
    """
    # Row order is randomised first and the index renumbered from zero.
    permuted = dataframe.sample(frac=1, random_state=random_state)
    permuted = permuted.reset_index(drop=True)

    parts = train_test_split(
        permuted,
        test_size=test_size,
        random_state=random_state,
    )
    return parts[0], parts[1]

class Custom_VIT_Model:
    """Binary image classifier on top of a frozen pre-trained ViT.

    The ``google/vit-base-patch16-224`` checkpoint is loaded, all of its
    parameters are frozen, and its classifier is replaced with a fresh linear
    layer with ``NUM_CLASSES`` outputs, which is the only trainable part.
    Labelled samples are accumulated in ``user_data.csv`` and the classifier
    is retrained once ``RETRAIN_THRESHOLD`` samples are stored.
    """

    # Number of output classes of the replacement classifier (labels 0 and 1).
    NUM_CLASSES = 2
    # Stored-sample count from which add_data() triggers retraining.
    RETRAIN_THRESHOLD = 100

    def __init__(self):
        """Load the pre-trained model, build the optimizer and load saved data."""
        # Use gpu if exist (nvidia only) else cpu (any)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the pre-trained ViT model
        self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

        # Freeze pre-trained layers
        for param in self.model.parameters():
            param.requires_grad = False

        # Define a new classifier that has 2 outputs (0,1); a freshly created
        # layer is trainable by default, unlike the frozen backbone above.
        self.model.classifier = nn.Linear(self.model.config.hidden_size, self.NUM_CLASSES)
        self.model.to(self.device)

        # Only the new classifier can receive gradients, so it is the only
        # thing the optimizer needs to track.
        self.optimizer = optim.Adam(self.model.classifier.parameters(), lr=0.001)

        # Set the image preprocessing (resize image) and make it tensor
        # NOTE(review): no Normalize step is applied here; the checkpoint was
        # presumably trained on normalized inputs - confirm this is intended.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

        # Initialize DataFrame for user data
        self.data_file = 'user_data.csv'
        if os.path.exists(self.data_file):
            self.df = pd.read_csv(self.data_file)
        else:
            self.df = pd.DataFrame(columns=['image_path', 'label'])

    def add_data(self, image_path: str, label: int):
        """Store one labelled sample and retrain once enough data exists.

        The sample is appended to the in-memory DataFrame and the whole table
        is written back to ``self.data_file``.

        Args:
            image_path: Path of the image file for this sample.
            label: Class id; must be an integer in ``range(NUM_CLASSES)``.

        Raises:
            ValueError: If ``label`` is not a valid class id. Nothing is
                stored or written in that case.
        """
        # Validate before persisting: a bad label written to the CSV would
        # break every later retraining run, not just this call.
        if label not in range(self.NUM_CLASSES):
            raise ValueError(
                f"label must be an integer in [0, {self.NUM_CLASSES - 1}], got {label!r}"
            )
        label = int(label)

        # Create a new DataFrame entry
        new_entry = pd.DataFrame({'image_path': [image_path], 'label': [label]})

        # Append the new entry to the existing DataFrame. Concatenating with
        # an empty frame is deprecated in pandas and would leave the label
        # column as object dtype, so the first sample replaces the frame.
        if self.df.empty:
            self.df = new_entry
        else:
            self.df = pd.concat([self.df, new_entry], ignore_index=True)

        # Save the updated DataFrame to the specified CSV file
        self.df.to_csv(self.data_file, index=False)

        # Print the current state of the training data for debugging
        print("Current training data:")
        print(self.df)

        # Check if we have enough images for retraining.
        # NOTE(review): once the threshold is reached this retrains on every
        # further call - confirm that is the intended cadence.
        if len(self.df) >= self.RETRAIN_THRESHOLD:
            print("Retraining the model as we have enough data.")
            self.retrain_model()

    def retrain_model(self):
        """Train the classifier on the stored samples and save the weights.

        The stored data is shuffled and split into training and validation
        parts; each of the 10 epochs prints the mean training loss and the
        validation accuracy. The model's state dict is saved to
        ``trained_model.pth`` afterwards.
        """
        # Shuffle and split the data
        train_df, val_df = shuffle_and_split_data(self.df)

        # Define the dataset and dataloaders
        train_dataset = CustomDataset(train_df, transform=self.preprocess)
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

        val_dataset = CustomDataset(val_df, transform=self.preprocess)
        val_loader = DataLoader(val_dataset, batch_size=32)

        # Define the loss function
        criterion = nn.CrossEntropyLoss().to(self.device)

        # Training loop
        num_epochs = 10
        for epoch in range(num_epochs):
            mean_loss = self._train_one_epoch(train_loader, criterion)
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}")

            accuracy = self._evaluate(val_loader)
            print(f"Validation Accuracy: {accuracy}")

        # Save the retrained model
        torch.save(self.model.state_dict(), 'trained_model.pth')
        print("Model retrained and updated!")

    def _train_one_epoch(self, loader, criterion):
        """Run one optimization pass over ``loader``; return the mean batch loss."""
        self.model.train()
        running_loss = 0.0
        for images, labels in loader:
            images = images.to(self.device)
            # CrossEntropyLoss needs integer (long) class indices.
            labels = labels.to(self.device, dtype=torch.long)

            self.optimizer.zero_grad()
            logits = self.model(images).logits  # Extract logits from the output
            loss = criterion(logits, labels)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
        # max() guards the division if the loader yields no batches.
        return running_loss / max(len(loader), 1)

    def _evaluate(self, loader):
        """Return the fraction of correctly predicted samples in ``loader``."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(self.device)
                labels = labels.to(self.device, dtype=torch.long)
                predicted = self.model(images).logits.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        # Report 0.0 rather than dividing by zero on an empty validation set.
        return correct / total if total else 0.0

# Script entry point: building the model loads the pre-trained ViT weights and
# any previously saved user data file; no training is started from here.
if __name__ == "__main__":
    custom_model = Custom_VIT_Model()