import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, average_precision_score, confusion_matrix)
import matplotlib.pyplot as plt
import seaborn as sns

from vit_model_training import labeling, CustomDataset


def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    # Shuffle the DataFrame
    shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)

    # Split the DataFrame into train and validation sets
    train_df, val_df = train_test_split(shuffled_df, test_size=test_size, random_state=random_state)

    return train_df, val_df


if __name__ == "__main__":
    # Use the GPU when available, otherwise fall back to the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pre-trained ViT model, swap in a binary (real/fake)
    # classification head, and move everything to the target device
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
    model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

    # Resize each image to the ViT input size and convert it to a tensor;
    # this must match the preprocessing used during training
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # Load the test dataset
    test_real_folder = 'datasets/test_set/real/'
    test_fake_folder = 'datasets/test_set/fake/'
    test_set = labeling(test_real_folder, test_fake_folder)
    test_dataset = CustomDataset(test_set, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Load the trained weights (map_location keeps this usable on CPU-only machines)
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))

    # Evaluate the model
    model.eval()
    true_labels = []
    predicted_labels = []
    predicted_scores = []  # positive-class probabilities, needed for average precision

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            logits = outputs.logits  # extract logits from the model output
            probs = torch.softmax(logits, dim=1)[:, 1]
            predicted = torch.argmax(logits, dim=1)
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())
            predicted_scores.extend(probs.cpu().numpy())

    # Calculate evaluation metrics. Average precision is computed from the
    # predicted probabilities rather than the hard labels, since it expects
    # continuous scores to rank the predictions.
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels)
    recall = recall_score(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    ap = average_precision_score(true_labels, predicted_scores)
    cm = confusion_matrix(true_labels, predicted_labels)

    print(f"Test Accuracy: {accuracy:.2%}")
    print(f"Precision: {precision:.2%}")
    print(f"Recall: {recall:.2%}")
    print(f"F1 Score: {f1:.2%}")
    print(f"Average Precision: {ap:.2%}")

    # Plot the confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()
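
    # Optional sketch: a per-class breakdown of precision/recall/F1 using
    # sklearn's classification_report. The target_names below are an
    # assumption that labeling() assigns label 0 to real images and label 1
    # to fake images; swap them if the training script uses the opposite
    # convention.
    from sklearn.metrics import classification_report
    print(classification_report(true_labels, predicted_labels,
                                target_names=['real', 'fake']))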