"""Evaluate a fine-tuned ViT binary classifier (real vs. fake) on a test set."""
import html
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification


def display_video(video_url):
    """Return an HTML5 ``<video>`` snippet that plays *video_url*.

    The URL is attribute-escaped so characters such as ``&`` and ``"`` in a
    query string cannot break out of the ``src`` attribute.
    """
    src = html.escape(video_url, quote=True)
    return (
        "<video controls>\n"
        f'  <source src="{src}" type="video/mp4">\n'
        "  Your browser does not support the video tag.\n"
        "</video>\n"
    )


def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle *dataframe* and split it into (train_df, val_df).

    The frame is shuffled with ``DataFrame.sample`` and then split with
    ``train_test_split`` (which shuffles again); both use *random_state*,
    so the split is reproducible.
    """
    shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)
    train_df, val_df = train_test_split(shuffled_df, test_size=test_size, random_state=random_state)
    return train_df, val_df


def _collect_predictions(model, loader, device):
    """Run *model* over *loader*; return (true labels, predicted labels, class-1 scores).

    Class-1 scores are the softmax probability of class 1, which is what
    ``average_precision_score`` expects (class 1 is treated as positive,
    matching sklearn's default ``pos_label=1``).
    """
    true_labels, predicted_labels, positive_scores = [], [], []
    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device)).logits
            probs = torch.softmax(logits, dim=1)
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(logits.argmax(dim=1).cpu().numpy())
            positive_scores.extend(probs[:, 1].cpu().numpy())
    return true_labels, predicted_labels, positive_scores


def _report_metrics(true_labels, predicted_labels, positive_scores):
    """Print accuracy/precision/F1/AP/recall and return the 2x2 confusion matrix."""
    accuracy = accuracy_score(true_labels, predicted_labels)
    # zero_division=0 returns the same 0.0 sklearn would, without the warning.
    precision = precision_score(true_labels, predicted_labels, zero_division=0)
    f1 = f1_score(true_labels, predicted_labels, zero_division=0)
    recall = recall_score(true_labels, predicted_labels, zero_division=0)
    # AP must be computed from continuous scores, not hard 0/1 predictions.
    ap = average_precision_score(true_labels, positive_scores)
    # Fix the label set so the matrix stays 2x2 even if a class is absent.
    cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1])

    print(f"Test Accuracy: {accuracy:.2%}")
    print(f"Precision: {precision:.2%}")
    print(f"F1 Score: {f1:.2%}")
    print(f"Average Precision: {ap:.2%}")
    print(f"Recall: {recall:.2%}")
    return cm


def _plot_confusion_matrix(cm):
    """Show *cm* as an annotated heatmap."""
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()


if __name__ == "__main__":
    # Use the GPU when available; fall back to CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the ViT architecture with a 2-way classification head.
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
    model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

    # NOTE(review): no transforms.Normalize step here; presumably this matches
    # the preprocessing used at training time — confirm before changing it.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # Load the test dataset.
    test_real_folder = 'datasets/test_set/real/'
    test_fake_folder = 'datasets/test_set/fake/'
    # NOTE(review): `labeling` and `CustomDataset` are neither defined nor
    # imported in this file; they must be brought into scope (e.g. from the
    # training script) or this raises NameError.
    test_set = labeling(test_real_folder, test_fake_folder)
    test_dataset = CustomDataset(test_set, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Load the trained weights; map_location lets a GPU-saved checkpoint load
    # on a CPU-only machine.
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))

    # Link to the video. This is a signed googlevideo.com URL: it embeds an
    # expiry timestamp and a client IP, so it will stop working — replace it
    # with your own video URL.
    video_url = 'https://rr5---sn-33uxaxjvh-aixe.googlevideo.com/videoplayback?expire=1727025979&ei=2_7vZrzMAuGdp-oPuaTo-QI&ip=39.62.1.120&id=o-AJ04-wA4jR6nhlg7B-yNUOXEwR7yoNlJetni5NaAoWRl&itag=134&aitags=133%2C134%2C135%2C136%2C137%2C160%2C242%2C243%2C244%2C247%2C248%2C278&source=youtube&requiressl=yes&xpc=EgVo2aDSNQ%3D%3D&mh=9z&mm=31%2C29&mn=sn-33uxaxjvh-aixe%2Csn-hju7enll&ms=au%2Crdu&mv=m&mvi=5&pl=24&initcwndbps=306250&bui=AXLXGFS9xlNb5y-figGb1FTTN1Ma8zVRiN7RtpZjebiJgICl7QFK5ab9UDZVXvwn2GwOYj4m4rXuQlYc&spc=54MbxY0qT7L8eXI7eMdKq6id860EyvqxATj5F0MLSzmNFdC1mD-XNkZUkcL1EWQ&vprv=1&svpuc=1&mime=video%2Fmp4&ns=qN73Wubd4RAEtRCu3S2dItYQ&rqh=1&gir=yes&clen=242900&dur=5.000&lmt=1727003020660351&mt=1727004294&fvip=2&keepalive=yes&fexp=51299152&c=WEB&sefc=1&txp=630A224&n=Y84SAecGmAZzwg&sparams=expire%2Cei%2Cip%2Cid%2Caitags%2Csource%2Crequiressl%2Cxpc%2Cbui%2Cspc%2Cvprv%2Csvpuc%2Cmime%2Cns%2Crqh%2Cgir%2Cclen%2Cdur%2Clmt&sig=AJfQdSswRAIgGjjE8lnq2bVWML91M2fA0A3qtumgsH-bASH-qjraIRwCIBh9oYh7GnjGwTNescuIZ1qgv4PBj0WOzJJbveuTUOb8&lsparams=mh%2Cmm%2Cmn%2Cms%2Cmv%2Cmvi%2Cpl%2Cinitcwndbps&lsig=ABPmVW0wRQIgCwVg3G31n-JXtH0t66MDGpnLR8s-mRwiTjMQP9TeTawCIQC2zaC1iwicMoTjn6ha46-W1UZrW6Rv9D8HP5I96C1hfg%3D%3D&extt=mp4'
    video_html = display_video(video_url)

    # Emit the video HTML once, before prediction, for the dashboard to pick up.
    print(video_html)

    true_labels, predicted_labels, positive_scores = _collect_predictions(
        model, test_loader, device
    )
    cm = _report_metrics(true_labels, predicted_labels, positive_scores)
    _plot_confusion_matrix(cm)