Image Classification
Transformers
English
art
File size: 4,099 Bytes
0dab8bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0012df3
 
0dab8bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
from PIL import Image
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import recall_score

def labeling(path_real, path_fake):
    """Build a labelled table of image files from a "real" and a "fake" folder.

    Files found directly inside ``path_real`` get label 0 and files found
    directly inside ``path_fake`` get label 1. Real entries come first,
    followed by fake entries; within each folder entries are visited in
    sorted filename order so the result is reproducible (``os.listdir``
    itself makes no ordering promise). Anything that is not a regular file
    (e.g. a nested directory) is skipped, since it cannot be opened as an
    image later on.

    Args:
        path_real: Directory containing the real images.
        path_fake: Directory containing the fake images.

    Returns:
        A ``pandas.DataFrame`` with the columns ``image_path`` (joined
        folder + filename) and ``label`` (0 or 1), in that order.

    Raises:
        OSError: Propagated from ``os.listdir`` when a folder is missing or
            is not a directory.
    """
    image_paths = []
    labels = []

    # One pass per (folder, label) pair instead of two copy-pasted loops.
    for folder, label in ((path_real, 0), (path_fake, 1)):
        for filename in sorted(os.listdir(folder)):
            full_path = os.path.join(folder, filename)
            if not os.path.isfile(full_path):
                continue
            image_paths.append(full_path)
            labels.append(label)

    return pd.DataFrame({'image_path': image_paths, 'label': labels})

class CustomDataset(Dataset):
    """Dataset serving ``(image, label)`` pairs described by a DataFrame.

    Rows are addressed positionally: column 0 of the frame must hold the
    image path and column 1 the label.
    """

    def __init__(self, dataframe, transform=None):
        """Store the source table and the optional per-image transform.

        Args:
            dataframe: Table whose first column is an image path and whose
                second column is the label.
            transform: Optional callable applied to the loaded PIL image.
        """
        self.transform = transform
        self.dataframe = dataframe

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            A ``(image, label)`` tuple. The image is opened with PIL and
            converted to RGB, then passed through ``transform`` when one
            was supplied; the label is returned as stored in the table.
        """
        table = self.dataframe

        # Column 0 is the path; always decode to three-channel RGB.
        picture = Image.open(table.iloc[idx, 0]).convert('RGB')
        sample = self.transform(picture) if self.transform else picture

        # Column 1 is the label.
        return sample, table.iloc[idx, 1]
    
def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Randomly permute ``dataframe`` and divide it into train/validation parts.

    The rows are first reordered with ``DataFrame.sample(frac=1)`` and
    re-indexed from zero, then handed to ``train_test_split``. Both steps
    are seeded with the same ``random_state``.

    Args:
        dataframe: Table to split.
        test_size: Passed through to ``train_test_split`` as the size of
            the validation part.
        random_state: Seed used for the shuffle and for the split.

    Returns:
        A ``(train_df, val_df)`` tuple.
    """
    # Full-table permutation with a fresh 0..n-1 index.
    permuted = dataframe.sample(frac=1, random_state=random_state)
    permuted = permuted.reset_index(drop=True)

    # Carve the permuted table into the two subsets.
    split_options = {'test_size': test_size, 'random_state': random_state}
    train_part, val_part = train_test_split(permuted, **split_options)
    return train_part, val_part


def main():
    """Evaluate the fine-tuned ViT real-vs-fake classifier on the test set.

    Loads ``google/vit-base-patch16-224`` with a 2-way classification head,
    restores the weights from ``trained_model.pth``, runs inference over the
    images under ``datasets/test_set/real`` (label 0) and
    ``datasets/test_set/fake`` (label 1), prints accuracy, precision, F1,
    average precision and recall, and shows the confusion matrix.
    """
    # Use the GPU when one is present, otherwise fall back to CPU so the
    # script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pre-trained ViT backbone and swap its head for a 2-class
    # linear layer so the parameter shapes match the saved checkpoint.
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
    model.classifier = nn.Linear(model.config.hidden_size, 2)

    # Define the image preprocessing pipeline.
    # NOTE(review): there is no Normalize step here; presumably this mirrors
    # the preprocessing used during training - confirm against the training
    # script before changing it.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # Load the test dataset
    test_real_folder = 'datasets/test_set/real'
    test_fake_folder = 'datasets/test_set/fake'
    test_set = labeling(test_real_folder, test_fake_folder)
    test_dataset = CustomDataset(test_set, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Load the trained weights. map_location remaps tensors saved from a
    # GPU run onto whichever device is in use now.
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))
    model.to(device)

    # Evaluate the model
    model.eval()
    true_labels = []
    predicted_labels = []
    positive_scores = []  # P(class 1) per sample, needed for average precision

    with torch.no_grad():
        for images, labels in test_loader:
            logits = model(images.to(device)).logits
            probabilities = torch.softmax(logits, dim=1)
            true_labels.extend(labels.tolist())
            predicted_labels.extend(logits.argmax(dim=1).cpu().tolist())
            positive_scores.extend(probabilities[:, 1].cpu().tolist())

    # Calculate evaluation metrics
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels)
    cm = confusion_matrix(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    # Average precision summarises the precision-recall curve, so it must be
    # fed continuous scores; hard 0/1 predictions collapse the curve to a
    # single operating point.
    ap = average_precision_score(true_labels, positive_scores)
    recall = recall_score(true_labels, predicted_labels)

    print(f"Test Accuracy: {accuracy:.2%}")
    print(f"Precision: {precision:.2%}")
    print(f"F1 Score: {f1:.2%}")
    print(f"Average Precision: {ap:.2%}")
    print(f"Recall: {recall:.2%}")

    # Plot the confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()


if __name__ == "__main__":
    main()