|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import transforms |
|
from transformers import ViTForImageClassification |
|
import os |
|
import pandas as pd |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score, recall_score |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
|
|
|
|
def display_video(video_url):
    """Return an HTML5 ``<video>`` snippet that plays *video_url*.

    The URL is HTML-escaped before being placed in the ``src`` attribute so
    that query-string ampersands are encoded as ``&amp;`` and a stray quote or
    angle bracket in the URL cannot terminate the attribute and inject markup.

    Args:
        video_url: Address of the video; it is converted with ``str()`` so any
            object that could be formatted into the original f-string still
            works.

    Returns:
        A multi-line HTML string containing a 640x480 autoplaying video
        element with a single ``video/mp4`` source.
    """
    # Local import keeps this function self-contained (the module's import
    # header does not pull in ``html``).
    import html

    safe_url = html.escape(str(video_url), quote=True)
    return f'''
    <video width="640" height="480" controls autoplay>
        <source src="{safe_url}" type="video/mp4">
        Your browser does not support the video tag.
    </video>
    '''
|
|
|
def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle *dataframe* and divide it into training and validation parts.

    Args:
        dataframe: Object exposing ``sample`` / ``reset_index`` (used here as
            a pandas DataFrame).
        test_size: Forwarded to ``train_test_split`` as the validation share.
        random_state: Seed used both for the row shuffle and for the split.

    Returns:
        A ``(train_df, val_df)`` tuple.
    """
    # Draw every row in random order, then renumber the index from zero so
    # the pre-shuffle index is discarded.
    randomized = dataframe.sample(frac=1, random_state=random_state)
    randomized = randomized.reset_index(drop=True)

    train_part, val_part = train_test_split(
        randomized,
        test_size=test_size,
        random_state=random_state,
    )
    return train_part, val_part
|
|
|
if __name__ == "__main__":
    # Evaluation script: rebuild the ViT classifier with a 2-class head, load
    # the weights stored in 'trained_model.pth', run it over the test folders
    # and report classification metrics plus a confusion-matrix heatmap.

    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # script does not crash on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)

    # Swap the pretrained head for a 2-output linear layer *before* loading
    # the checkpoint (presumably so parameter shapes match the saved
    # state dict; confirm against the training script).
    model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    test_real_folder = 'datasets/test_set/real/'
    test_fake_folder = 'datasets/test_set/fake/'

    # NOTE(review): `labeling` and `CustomDataset` are not defined in the
    # visible part of this file; they must be provided elsewhere.
    test_set = labeling(test_real_folder, test_fake_folder)
    test_dataset = CustomDataset(test_set, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # map_location lets a checkpoint that was saved from a CUDA device be
    # restored on whichever device was selected above.
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))

    model.eval()
    true_labels = []
    predicted_labels = []
    # Probability assigned to class 1 for every sample; needed for average
    # precision, which ranks samples by a continuous score.
    positive_scores = []

    video_url = 'https://rr5---sn-33uxaxjvh-aixe.googlevideo.com/videoplayback?expire=1727025979&ei=2_7vZrzMAuGdp-oPuaTo-QI&ip=39.62.1.120&id=o-AJ04-wA4jR6nhlg7B-yNUOXEwR7yoNlJetni5NaAoWRl&itag=134&aitags=133%2C134%2C135%2C136%2C137%2C160%2C242%2C243%2C244%2C247%2C248%2C278&source=youtube&requiressl=yes&xpc=EgVo2aDSNQ%3D%3D&mh=9z&mm=31%2C29&mn=sn-33uxaxjvh-aixe%2Csn-hju7enll&ms=au%2Crdu&mv=m&mvi=5&pl=24&initcwndbps=306250&bui=AXLXGFS9xlNb5y-figGb1FTTN1Ma8zVRiN7RtpZjebiJgICl7QFK5ab9UDZVXvwn2GwOYj4m4rXuQlYc&spc=54MbxY0qT7L8eXI7eMdKq6id860EyvqxATj5F0MLSzmNFdC1mD-XNkZUkcL1EWQ&vprv=1&svpuc=1&mime=video%2Fmp4&ns=qN73Wubd4RAEtRCu3S2dItYQ&rqh=1&gir=yes&clen=242900&dur=5.000&lmt=1727003020660351&mt=1727004294&fvip=2&keepalive=yes&fexp=51299152&c=WEB&sefc=1&txp=630A224&n=Y84SAecGmAZzwg&sparams=expire%2Cei%2Cip%2Cid%2Caitags%2Csource%2Crequiressl%2Cxpc%2Cbui%2Cspc%2Cvprv%2Csvpuc%2Cmime%2Cns%2Crqh%2Cgir%2Cclen%2Cdur%2Clmt&sig=AJfQdSswRAIgGjjE8lnq2bVWML91M2fA0A3qtumgsH-bASH-qjraIRwCIBh9oYh7GnjGwTNescuIZ1qgv4PBj0WOzJJbveuTUOb8&lsparams=mh%2Cmm%2Cmn%2Cms%2Cmv%2Cmvi%2Cpl%2Cinitcwndbps&lsig=ABPmVW0wRQIgCwVg3G31n-JXtH0t66MDGpnLR8s-mRwiTjMQP9TeTawCIQC2zaC1iwicMoTjn6ha46-W1UZrW6Rv9D8HP5I96C1hfg%3D%3D&extt=mp4'
    video_html = display_video(video_url)

    print(video_html)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            # NOTE(review): the same HTML snippet is emitted once per batch;
            # kept to preserve the script's output, but this looks
            # unintentional - confirm whether it can be dropped.
            print(video_html)

            outputs = model(images)
            logits = outputs.logits
            _, predicted = torch.max(logits, 1)
            class_one_prob = torch.softmax(logits, dim=1)[:, 1]

            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())
            positive_scores.extend(class_one_prob.cpu().numpy())

    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels)
    cm = confusion_matrix(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    # Average precision summarises the precision-recall curve, so it is fed
    # the continuous class-1 probabilities rather than thresholded labels.
    ap = average_precision_score(true_labels, positive_scores)
    recall = recall_score(true_labels, predicted_labels)

    print(f"Test Accuracy: {accuracy:.2%}")
    print(f"Precision: {precision:.2%}")
    print(f"F1 Score: {f1:.2%}")
    print(f"Average Precision: {ap:.2%}")
    print(f"Recall: {recall:.2%}")

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()
|
|