Image Classification
Transformers
English
art
litav commited on
Commit
6151861
verified
1 Parent(s): 9f32d6e

Update vit_model_test.py

Browse files
Files changed (1) hide show
  1. vit_model_test.py +27 -12
vit_model_test.py CHANGED
@@ -7,24 +7,35 @@ from PIL import Image
7
  import os
8
  import pandas as pd
9
  from sklearn.model_selection import train_test_split
10
- from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score
11
  import matplotlib.pyplot as plt
12
  import seaborn as sns
13
- from sklearn.metrics import recall_score
14
- from vit_model_traning import labeling,CustomDataset
15
 
16
 
17
-
18
  def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
19
- # Shuffle the DataFrame
20
  shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)
21
-
22
- # Split the DataFrame into train and validation sets
23
  train_df, val_df = train_test_split(shuffled_df, test_size=test_size, random_state=random_state)
24
-
25
  return train_df, val_df
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  if __name__ == "__main__":
29
  # Check for GPU availability
30
  device = torch.device('cuda')
@@ -32,9 +43,8 @@ if __name__ == "__main__":
32
  # Load the pre-trained ViT model and move it to GPU
33
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
34
 
35
-
36
-
37
  model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
 
38
  # Define the image preprocessing pipeline
39
  preprocess = transforms.Compose([
40
  transforms.Resize((224, 224)),
@@ -57,6 +67,9 @@ if __name__ == "__main__":
57
  true_labels = []
58
  predicted_labels = []
59
 
 
 
 
60
  with torch.no_grad():
61
  for images, labels in test_loader:
62
  images, labels = images.to(device), labels.to(device)
@@ -74,7 +87,6 @@ if __name__ == "__main__":
74
  ap = average_precision_score(true_labels, predicted_labels)
75
  recall = recall_score(true_labels, predicted_labels)
76
 
77
-
78
  print(f"Test Accuracy: {accuracy:.2%}")
79
  print(f"Precision: {precision:.2%}")
80
  print(f"F1 Score: {f1:.2%}")
@@ -87,4 +99,7 @@ if __name__ == "__main__":
87
  plt.xlabel('Predicted Labels')
88
  plt.ylabel('True Labels')
89
  plt.title('Confusion Matrix')
90
- plt.show()
 
 
 
 
7
  import os
8
  import pandas as pd
9
  from sklearn.model_selection import train_test_split
10
+ from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score, recall_score
11
  import matplotlib.pyplot as plt
12
  import seaborn as sns
13
+ import cv2 # 住驻专讬讬转 OpenCV 诇讛爪讙转 讛讜讬讚讗讜
14
+ from vit_model_traning import labeling, CustomDataset
15
 
16
 
 
17
  def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
 
18
  shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)
 
 
19
  train_df, val_df = train_test_split(shuffled_df, test_size=test_size, random_state=random_state)
 
20
  return train_df, val_df
21
 
22
 
23
+ def play_animation(video_path):
24
+ cap = cv2.VideoCapture(video_path)
25
+
26
+ while cap.isOpened():
27
+ ret, frame = cap.read()
28
+ if not ret:
29
+ break
30
+ cv2.imshow('Processing Animation', frame)
31
+
32
+ # Press 'q' to exit early
33
+ if cv2.waitKey(25) & 0xFF == ord('q'):
34
+ break
35
+
36
+ cap.release()
37
+ cv2.destroyAllWindows()
38
+
39
  if __name__ == "__main__":
40
  # Check for GPU availability
41
  device = torch.device('cuda')
 
43
  # Load the pre-trained ViT model and move it to GPU
44
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
45
 
 
 
46
  model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
47
+
48
  # Define the image preprocessing pipeline
49
  preprocess = transforms.Compose([
50
  transforms.Resize((224, 224)),
 
67
  true_labels = []
68
  predicted_labels = []
69
 
70
+ # Play animation while processing
71
+ play_animation('https://huggingface.co/DataScienceProject/Vit/blob/main/0001-0120.mp4') # 诪住诇讜诇 诇住专讟讜谉 讛讗谞讬诪爪讬讛 砖诇讱
72
+
73
  with torch.no_grad():
74
  for images, labels in test_loader:
75
  images, labels = images.to(device), labels.to(device)
 
87
  ap = average_precision_score(true_labels, predicted_labels)
88
  recall = recall_score(true_labels, predicted_labels)
89
 
 
90
  print(f"Test Accuracy: {accuracy:.2%}")
91
  print(f"Precision: {precision:.2%}")
92
  print(f"F1 Score: {f1:.2%}")
 
99
  plt.xlabel('Predicted Labels')
100
  plt.ylabel('True Labels')
101
  plt.title('Confusion Matrix')
102
+ plt.show()
103
+
104
+ # Play animation again if needed
105
+ # play_animation('path_to_your_animation.mp4')