Image Classification · Transformers · English · art
litav committed (verified)
Commit 94ce1e2
Parent(s): 6151861

Update vit_model_test.py

Files changed (1)
  1. vit_model_test.py +13 -25
vit_model_test.py CHANGED
@@ -3,39 +3,26 @@ import torch.nn as nn
 from torch.utils.data import Dataset, DataLoader
 from torchvision import transforms
 from transformers import ViTForImageClassification
-from PIL import Image
 import os
 import pandas as pd
 from sklearn.model_selection import train_test_split
 from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score, recall_score
 import matplotlib.pyplot as plt
 import seaborn as sns
-import cv2  # OpenCV library for displaying the video
-from vit_model_traning import labeling, CustomDataset
 
+# Function for displaying a video
+def display_video(video_url):
+    video_html = f'''
+    <iframe width="560" height="315" src="{video_url}" frameborder="0" allowfullscreen></iframe>
+    '''
+    # Place the HTML in your dashboard
+    return video_html
 
 def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
     shuffled_df = dataframe.sample(frac=1, random_state=random_state).reset_index(drop=True)
     train_df, val_df = train_test_split(shuffled_df, test_size=test_size, random_state=random_state)
     return train_df, val_df
 
-
-def play_animation(video_path):
-    cap = cv2.VideoCapture(video_path)
-
-    while cap.isOpened():
-        ret, frame = cap.read()
-        if not ret:
-            break
-        cv2.imshow('Processing Animation', frame)
-
-        # Press 'q' to exit early
-        if cv2.waitKey(25) & 0xFF == ord('q'):
-            break
-
-    cap.release()
-    cv2.destroyAllWindows()
-
 if __name__ == "__main__":
     # Check for GPU availability
     device = torch.device('cuda')
@@ -67,8 +54,12 @@ if __name__ == "__main__":
     true_labels = []
     predicted_labels = []
 
-    # Play animation while processing
-    play_animation('https://huggingface.co/DataScienceProject/Vit/blob/main/0001-0120.mp4')  # Path to your animation video
+    # Link to the video
+    video_url = 'https://youtube.com/shorts/vGRq060nPYU?feature=share'  # Replace this with the URL of your video
+    video_html = display_video(video_url)
+
+    # Show the video before the prediction
+    print(video_html)  # Display the HTML in your dashboard
 
     with torch.no_grad():
         for images, labels in test_loader:
@@ -100,6 +91,3 @@ if __name__ == "__main__":
     plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
     plt.show()
-
-    # Play animation again if needed
-    # play_animation('path_to_your_animation.mp4')
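The new display_video helper only builds an iframe snippet as a string, and print(video_html) writes that raw markup to stdout rather than rendering anything. A minimal sketch of how the returned string could actually be displayed, assuming the script runs inside a Jupyter/IPython session (the notebook context, the IPython.display usage, and the import of display_video from this file are assumptions for illustration, not part of the commit):

from IPython.display import HTML, display

from vit_model_test import display_video  # helper added in this commit

# Note: YouTube generally only allows iframe embedding via its /embed/<id> URLs,
# so a Shorts share link may need to be rewritten before the player renders.
video_url = 'https://youtube.com/shorts/vGRq060nPYU?feature=share'
display(HTML(display_video(video_url)))  # renders the iframe in the notebook output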