File size: 4,840 Bytes
0dab8bf 5b16cb6 6151861 5b16cb6 0dab8bf 14cad36 94ce1e2 14cad36 94ce1e2 5b16cb6 0dab8bf 6151861 0dab8bf 0bbf37a 5b16cb6 0dab8bf 5b16cb6 0dab8bf 94ce1e2 acdba2f 94ce1e2 14cad36 6151861 0dab8bf 5b16cb6 14cad36 0dab8bf 5b16cb6 0dab8bf 5b16cb6 6151861 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
# Function that returns the HTML markup for a video
def display_video(video_url):
    """Return an HTML ``<video>`` snippet that plays ``video_url``.

    The URL is HTML-escaped before being placed in the ``src`` attribute so
    that characters such as ``"``, ``<`` or ``&`` cannot terminate the
    attribute or inject extra markup.

    Args:
        video_url: Address of the video; it is declared to the browser as
            ``video/mp4``.

    Returns:
        A string holding a 640x480 ``<video>`` element with controls and
        autoplay enabled, wrapped in a leading and a trailing newline.
    """
    import html  # local import keeps this function self-contained

    # quote=True also escapes double quotes, which delimit the attribute.
    safe_url = html.escape(str(video_url), quote=True)
    return (
        "\n"
        '<video width="640" height="480" controls autoplay>\n'
        f'<source src="{safe_url}" type="video/mp4">\n'
        "Your browser does not support the video tag.\n"
        "</video>\n"
    )
def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle ``dataframe`` and split it into training and validation parts.

    Args:
        dataframe: Table whose rows are to be shuffled and split.
        test_size: Forwarded to ``train_test_split`` as the size of the
            second (validation) part.
        random_state: Seed used for both the row shuffle and the split.

    Returns:
        A ``(train_df, val_df)`` tuple.
    """
    # Sample every row (frac=1) to get a random order, then renumber from 0.
    reordered = dataframe.sample(frac=1, random_state=random_state)
    reordered = reordered.reset_index(drop=True)

    return tuple(
        train_test_split(
            reordered,
            test_size=test_size,
            random_state=random_state,
        )
    )
if __name__ == "__main__":
    # Evaluate the fine-tuned ViT classifier on the test set: load the
    # weights, run inference batch by batch, print the metrics and plot the
    # confusion matrix.

    # Use the GPU when one is available; otherwise fall back to the CPU so
    # the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pre-trained ViT model and move it to the selected device.
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
    # Replace the classification head with a fresh 2-output linear layer.
    model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

    # Define the image preprocessing pipeline.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Load the test dataset.
    # NOTE(review): `labeling` and `CustomDataset` are neither defined nor
    # imported in the visible part of this file - confirm they are provided
    # elsewhere before running.
    test_real_folder = 'datasets/test_set/real/'
    test_fake_folder = 'datasets/test_set/fake/'
    test_set = labeling(test_real_folder, test_fake_folder)
    test_dataset = CustomDataset(test_set, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Load the trained weights. map_location remaps tensors saved on another
    # device (e.g. a GPU checkpoint opened on a CPU-only host).
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))

    # Evaluate the model.
    model.eval()
    true_labels = []
    predicted_labels = []
    # Probability assigned to class 1 for every sample; needed for average
    # precision, which ranks samples by a continuous score.
    positive_scores = []

    # Link to the video.
    video_url = 'https://rr5---sn-33uxaxjvh-aixe.googlevideo.com/videoplayback?expire=1727025979&ei=2_7vZrzMAuGdp-oPuaTo-QI&ip=39.62.1.120&id=o-AJ04-wA4jR6nhlg7B-yNUOXEwR7yoNlJetni5NaAoWRl&itag=134&aitags=133%2C134%2C135%2C136%2C137%2C160%2C242%2C243%2C244%2C247%2C248%2C278&source=youtube&requiressl=yes&xpc=EgVo2aDSNQ%3D%3D&mh=9z&mm=31%2C29&mn=sn-33uxaxjvh-aixe%2Csn-hju7enll&ms=au%2Crdu&mv=m&mvi=5&pl=24&initcwndbps=306250&bui=AXLXGFS9xlNb5y-figGb1FTTN1Ma8zVRiN7RtpZjebiJgICl7QFK5ab9UDZVXvwn2GwOYj4m4rXuQlYc&spc=54MbxY0qT7L8eXI7eMdKq6id860EyvqxATj5F0MLSzmNFdC1mD-XNkZUkcL1EWQ&vprv=1&svpuc=1&mime=video%2Fmp4&ns=qN73Wubd4RAEtRCu3S2dItYQ&rqh=1&gir=yes&clen=242900&dur=5.000&lmt=1727003020660351&mt=1727004294&fvip=2&keepalive=yes&fexp=51299152&c=WEB&sefc=1&txp=630A224&n=Y84SAecGmAZzwg&sparams=expire%2Cei%2Cip%2Cid%2Caitags%2Csource%2Crequiressl%2Cxpc%2Cbui%2Cspc%2Cvprv%2Csvpuc%2Cmime%2Cns%2Crqh%2Cgir%2Cclen%2Cdur%2Clmt&sig=AJfQdSswRAIgGjjE8lnq2bVWML91M2fA0A3qtumgsH-bASH-qjraIRwCIBh9oYh7GnjGwTNescuIZ1qgv4PBj0WOzJJbveuTUOb8&lsparams=mh%2Cmm%2Cmn%2Cms%2Cmv%2Cmvi%2Cpl%2Cinitcwndbps&lsig=ABPmVW0wRQIgCwVg3G31n-JXtH0t66MDGpnLR8s-mRwiTjMQP9TeTawCIQC2zaC1iwicMoTjn6ha46-W1UZrW6Rv9D8HP5I96C1hfg%3D%3D&extt=mp4'  # Replace this with the URL of your own video
    video_html = display_video(video_url)

    # Show the video before prediction starts.
    print(video_html)  # This is meant to display the HTML in your dashboard

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            # Show the video while predicting (printed once per batch).
            print(video_html)  # Display the video HTML

            outputs = model(images)
            logits = outputs.logits  # Extract logits from the output
            _, predicted = torch.max(logits, 1)

            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())
            # Column 1 of the softmax is the probability of class 1, which
            # scikit-learn treats as the positive label by default.
            positive_scores.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())

    # Calculate evaluation metrics.
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels)
    cm = confusion_matrix(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    # Average precision summarises the precision-recall curve, so it must be
    # computed from continuous scores rather than thresholded predictions.
    ap = average_precision_score(true_labels, positive_scores)
    recall = recall_score(true_labels, predicted_labels)

    print(f"Test Accuracy: {accuracy:.2%}")
    print(f"Precision: {precision:.2%}")
    print(f"F1 Score: {f1:.2%}")
    print(f"Average Precision: {ap:.2%}")
    print(f"Recall: {recall:.2%}")

    # Plot the confusion matrix.
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()
|