Image Classification
Transformers
English
art
File size: 3,539 Bytes
0dab8bf
 
 
 
 
 
 
5b16cb6
6151861
5b16cb6
 
0dab8bf
94ce1e2
 
 
 
 
 
 
5b16cb6
 
 
 
 
 
0dab8bf
 
 
 
 
 
 
 
6151861
0dab8bf
 
 
 
 
 
 
0bbf37a
 
 
5b16cb6
 
0dab8bf
 
 
 
 
 
 
5b16cb6
0dab8bf
 
94ce1e2
 
 
 
 
 
6151861
0dab8bf
5b16cb6
 
0dab8bf
 
5b16cb6
 
0dab8bf
 
5b16cb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6151861
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns

# Function for displaying a video
def display_video(video_url):
    """Return an HTML ``<iframe>`` snippet that embeds the video at *video_url*.

    The URL is HTML-escaped before it is placed inside the ``src`` attribute,
    so characters such as ``"``, ``<`` or ``&`` cannot terminate the attribute
    or inject additional markup. URLs without such characters produce exactly
    the same markup as before.

    Args:
        video_url: Address of the video to embed.

    Returns:
        The iframe markup as a string. Nothing is rendered here; the caller
        decides where to place the HTML (for example, in a dashboard).
    """
    # Local import keeps this function self-contained.
    from html import escape

    safe_url = escape(str(video_url), quote=True)
    video_html = f'''
    <iframe width="560" height="315" src="{safe_url}" frameborder="0" allowfullscreen></iframe>
    '''
    # Place the returned HTML in your dashboard.
    return video_html

def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle *dataframe* and split it into training and validation parts.

    The rows are first reordered with ``DataFrame.sample(frac=1)`` and the
    index is reset, then the result is handed to ``train_test_split``. The
    same ``random_state`` seeds both steps.

    Args:
        dataframe: The table to shuffle and split.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        random_state: Seed used for both the shuffle and the split.

    Returns:
        A ``(train_df, val_df)`` tuple.
    """
    reordered = dataframe.sample(frac=1, random_state=random_state)
    reordered = reordered.reset_index(drop=True)
    parts = train_test_split(
        reordered,
        test_size=test_size,
        random_state=random_state,
    )
    training_part, validation_part = parts
    return training_part, validation_part

if __name__ == "__main__":
    # Use the GPU when one is available; otherwise fall back to the CPU so the
    # evaluation still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pre-trained ViT model and move it to the selected device
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)

    # Replace the classification head with a two-output linear layer
    model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

    # Define the image preprocessing pipeline: resize to 224x224, convert to tensor.
    # NOTE(review): there is no normalization step here - confirm this matches
    # the preprocessing used when trained_model.pth was produced.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    # Load the test dataset
    # NOTE(review): `labeling` and `CustomDataset` are neither defined nor
    # imported in the visible part of this file - confirm where they come from.
    test_real_folder = 'datasets/test_set/real/'
    test_fake_folder = 'datasets/test_set/fake/'

    test_set = labeling(test_real_folder, test_fake_folder)
    test_dataset = CustomDataset(test_set, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Load the trained weights. map_location remaps the stored tensors onto the
    # selected device, so a checkpoint saved on a GPU also loads on a CPU.
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))

    # Evaluate the model
    model.eval()
    true_labels = []
    predicted_labels = []
    # Probability assigned to class index 1 for every sample. Class 1 is the
    # positive class for the metrics below because they use the default
    # pos_label=1.
    positive_scores = []

    # Link to the video
    video_url = 'https://youtube.com/shorts/vGRq060nPYU?feature=share'  # Replace with the URL of your video
    video_html = display_video(video_url)

    # Show the video before the prediction
    print(video_html)  # Display the HTML in your dashboard

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            logits = outputs.logits  # Extract logits from the output
            _, predicted = torch.max(logits, 1)
            probabilities = torch.softmax(logits, dim=1)
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())
            positive_scores.extend(probabilities[:, 1].cpu().numpy())

    # Calculate evaluation metrics
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels)
    cm = confusion_matrix(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    # Average precision summarizes the precision-recall curve, so it needs
    # continuous scores rather than the thresholded 0/1 predictions.
    ap = average_precision_score(true_labels, positive_scores)
    recall = recall_score(true_labels, predicted_labels)

    print(f"Test Accuracy: {accuracy:.2%}")
    print(f"Precision: {precision:.2%}")
    print(f"F1 Score: {f1:.2%}")
    print(f"Average Precision: {ap:.2%}")
    print(f"Recall: {recall:.2%}")

    # Plot the confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()