import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    average_precision_score,
    confusion_matrix,
)
import matplotlib.pyplot as plt
import seaborn as sns

from vit_model_training import labeling, CustomDataset

def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle a dataframe, then split it into train and validation sets."""
    # Shuffling first is technically redundant (train_test_split shuffles by
    # default), but it keeps the split explicit and reproducible.
    shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)
    train_df, val_df = train_test_split(shuffled_df, test_size=test_size, random_state=random_state)
    return train_df, val_df


if __name__ == "__main__":
    # Use the GPU when available; fall back to CPU so the script still runs.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load the pretrained ViT backbone and swap its 1000-way ImageNet head
    # for a two-class (real vs. fake) classification head.
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
    model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

    # Test-time preprocessing; this must match the pipeline used in training.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
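
    # Note: 'google/vit-base-patch16-224' checkpoints are normally fed images
    # normalized with the processor's mean/std (see ViTImageProcessor in
    # transformers). Normalization is omitted here on the assumption that
    # vit_model_training used this same bare Resize+ToTensor pipeline.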
    # Build the labeled test set from the real/fake image folders.
    test_real_folder = 'datasets/test_set/real/'
    test_fake_folder = 'datasets/test_set/fake/'
    test_set = labeling(test_real_folder, test_fake_folder)
    test_dataset = CustomDataset(test_set, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)
    # Load the fine-tuned weights; map_location keeps this usable on CPU-only machines.
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))
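    # On recent PyTorch releases, torch.load(..., weights_only=True) is the
    # safer option when loading a bare state dict like this one.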
    # Evaluate on the held-out test set.
    model.eval()
    true_labels = []
    predicted_labels = []
    predicted_scores = []  # positive-class probabilities, needed for average precision

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=1)
            _, predicted = torch.max(logits, 1)
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())
            predicted_scores.extend(probs[:, 1].cpu().numpy())
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels)
    recall = recall_score(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    # Average precision is a ranking metric: it expects continuous scores,
    # not hard 0/1 predictions.
    ap = average_precision_score(true_labels, predicted_scores)
    cm = confusion_matrix(true_labels, predicted_labels)
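
    # Optional extras (sketches, assuming label 1 = "fake" is the positive class):
    # from sklearn.metrics import classification_report, roc_auc_score
    # print(classification_report(true_labels, predicted_labels, target_names=['real', 'fake']))
    # print(f"ROC AUC: {roc_auc_score(true_labels, predicted_scores):.2%}")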
    print(f"Test Accuracy: {accuracy:.2%}")
    print(f"Precision: {precision:.2%}")
    print(f"Recall: {recall:.2%}")
    print(f"F1 Score: {f1:.2%}")
    print(f"Average Precision: {ap:.2%}")

    # Visualize the confusion matrix.
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()
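    # Optional: pass xticklabels/yticklabels to sns.heatmap (e.g. ['real', 'fake'],
    # once the index mapping from labeling() is confirmed), and call
    # plt.savefig('confusion_matrix.png') before plt.show() to keep a copy.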