Image Classification
Transformers
English
art
litav committed on
Commit
b74cec0
verified
1 Parent(s): 12d2ff6

Update vit_model_test.py

Browse files
Files changed (1) hide show
  1. vit_model_test.py +13 -18
vit_model_test.py CHANGED
@@ -3,24 +3,28 @@ import torch.nn as nn
3
  from torch.utils.data import Dataset, DataLoader
4
  from torchvision import transforms
5
  from transformers import ViTForImageClassification
 
6
  import os
7
  import pandas as pd
8
  from sklearn.model_selection import train_test_split
9
- from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score, recall_score
10
  import matplotlib.pyplot as plt
11
  import seaborn as sns
 
 
12
 
13
- # פונקציה להחזרת HTML של סרטון
14
- def display_video(video_url):
15
- return f'''
16
- <iframe width="640" height="480" src="{video_url}" frameborder="0" allowfullscreen></iframe>
17
- '''
18
 
 
19
  def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
 
20
  shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)
 
 
21
  train_df, val_df = train_test_split(shuffled_df, test_size=test_size, random_state=random_state)
 
22
  return train_df, val_df
23
 
 
24
  if __name__ == "__main__":
25
  # Check for GPU availability
26
  device = torch.device('cuda')
@@ -28,8 +32,9 @@ if __name__ == "__main__":
28
  # Load the pre-trained ViT model and move it to GPU
29
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
30
 
 
 
31
  model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
32
-
33
  # Define the image preprocessing pipeline
34
  preprocess = transforms.Compose([
35
  transforms.Resize((224, 224)),
@@ -52,20 +57,9 @@ if __name__ == "__main__":
52
  true_labels = []
53
  predicted_labels = []
54
 
55
- # קישור לסרטון ביוטיוב
56
- video_url = 'https://www.youtube.com/embed/vGRq060nPYU' # החלף ב-URL של הסרטון שלך
57
- video_html = display_video(video_url)
58
-
59
- # הראי את הסרטון לפני החיזוי
60
- print(video_html) # זה אמור להציג את ה-HTML בדשבורד שלך
61
-
62
  with torch.no_grad():
63
  for images, labels in test_loader:
64
  images, labels = images.to(device), labels.to(device)
65
-
66
- # הראה את הסרטון בעת חיזוי
67
- print(video_html) # הצג את ה-HTML של הסרטון
68
-
69
  outputs = model(images)
70
  logits = outputs.logits # Extract logits from the output
71
  _, predicted = torch.max(logits, 1)
@@ -80,6 +74,7 @@ if __name__ == "__main__":
80
  ap = average_precision_score(true_labels, predicted_labels)
81
  recall = recall_score(true_labels, predicted_labels)
82
 
 
83
  print(f"Test Accuracy: {accuracy:.2%}")
84
  print(f"Precision: {precision:.2%}")
85
  print(f"F1 Score: {f1:.2%}")
 
3
  from torch.utils.data import Dataset, DataLoader
4
  from torchvision import transforms
5
  from transformers import ViTForImageClassification
6
+ from PIL import Image
7
  import os
8
  import pandas as pd
9
  from sklearn.model_selection import train_test_split
10
+ from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score
11
  import matplotlib.pyplot as plt
12
  import seaborn as sns
13
+ from sklearn.metrics import recall_score
14
+ from vit_model_traning import labeling,CustomDataset
15
 
 
 
 
 
 
16
 
17
+
18
  def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
19
+ # Shuffle the DataFrame
20
  shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)
21
+
22
+ # Split the DataFrame into train and validation sets
23
  train_df, val_df = train_test_split(shuffled_df, test_size=test_size, random_state=random_state)
24
+
25
  return train_df, val_df
26
 
27
+
28
  if __name__ == "__main__":
29
  # Check for GPU availability
30
  device = torch.device('cuda')
 
32
  # Load the pre-trained ViT model and move it to GPU
33
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
34
 
35
+
36
+
37
  model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
 
38
  # Define the image preprocessing pipeline
39
  preprocess = transforms.Compose([
40
  transforms.Resize((224, 224)),
 
57
  true_labels = []
58
  predicted_labels = []
59
 
 
 
 
 
 
 
 
60
  with torch.no_grad():
61
  for images, labels in test_loader:
62
  images, labels = images.to(device), labels.to(device)
 
 
 
 
63
  outputs = model(images)
64
  logits = outputs.logits # Extract logits from the output
65
  _, predicted = torch.max(logits, 1)
 
74
  ap = average_precision_score(true_labels, predicted_labels)
75
  recall = recall_score(true_labels, predicted_labels)
76
 
77
+
78
  print(f"Test Accuracy: {accuracy:.2%}")
79
  print(f"Precision: {precision:.2%}")
80
  print(f"F1 Score: {f1:.2%}")