Update vit_model_test.py
vit_model_test.py  CHANGED  +9 -22
@@ -6,24 +6,14 @@ from transformers import ViTForImageClassification
 import os
 import pandas as pd
 from sklearn.model_selection import train_test_split
-from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score
+from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score, recall_score
 import matplotlib.pyplot as plt
 import seaborn as sns
-from sklearn.metrics import recall_score
-from vit_model_traning import labeling, CustomDataset
 
 # Function that returns the HTML for a video
 def display_video(video_url):
     return f'''
-    <
-    <video width="640" height="480" controls autoplay>
-        <source src="{video_url}" type="video/mp4">
-        Your browser does not support the video tag.
-    </video>
-    </div>
-    <script>
-        document.getElementById('video-container').style.display = 'block';
-    </script>
+    <iframe width="640" height="480" src="{video_url}" frameborder="0" allowfullscreen></iframe>
     '''
 
 def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
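The updated display_video() now just returns raw <iframe> markup, and later in the file it is only print()-ed. If the test is driven from a notebook or another HTML-rendering environment, a minimal sketch like the following would actually embed the player (IPython availability is an assumption; this is not part of the commit):

# Sketch: render the iframe returned by the updated display_video() instead of
# printing the raw markup. Assumes an IPython/Jupyter environment.
from IPython.display import HTML, display

def display_video(video_url):
    # same body as the updated function in this commit
    return f'''
    <iframe width="640" height="480" src="{video_url}" frameborder="0" allowfullscreen></iframe>
    '''

display(HTML(display_video('https://www.youtube.com/embed/vGRq060nPYU')))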
@@ -39,7 +29,7 @@ if __name__ == "__main__":
     model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
 
     model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)
-
+
     # Define the image preprocessing pipeline
     preprocess = transforms.Compose([
         transforms.Resize((224, 224)),
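The context lines in this hunk show the model setup: the pretrained ViT head is swapped for a 2-class nn.Linear and an image preprocessing Compose is started. For reference, a self-contained sketch of that setup; everything after Resize falls outside the hunk, so the ToTensor/Normalize steps (and their 0.5 mean/std values) are assumptions:

# Standalone sketch of the setup visible in the context lines of this hunk.
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import ViTForImageClassification

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pretrained ViT and replace its 1000-class head with a 2-class one.
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

# Image preprocessing pipeline; ViT-base/16 expects 224x224 inputs.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),                                            # assumption
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # assumption
])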
@@ -57,20 +47,17 @@ if __name__ == "__main__":
     # Load the trained model
     model.load_state_dict(torch.load('trained_model.pth'))
 
-    # Link to the video
-    video_url = '"C:\Users\litav\Downloads\0001-0120.mp4"'  # replace this with the URL of your video
-    video_html = display_video(video_url)
-
-    # Show the video when the button is clicked
-    print(video_html)  # this should display the HTML in your dashboard
-
     # Evaluate the model
     model.eval()
     true_labels = []
     predicted_labels = []
 
-    #
-    #
+    # Link to a YouTube video
+    video_url = 'https://www.youtube.com/embed/vGRq060nPYU'  # replace with the URL of your video
+    video_html = display_video(video_url)
+
+    # Show the video before the prediction
+    print(video_html)  # this should display the HTML in your dashboard
 
     with torch.no_grad():
         for images, labels in test_loader:
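The loop at the end of the last hunk fills true_labels and predicted_labels, and the import change at the top consolidates the sklearn metrics they are meant to feed. A minimal sketch of how those two lists are typically scored and visualized with the imported functions; the placeholder data and the plotting details are assumptions, not taken from vit_model_test.py:

# Sketch: score and plot the label lists collected in the evaluation loop.
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, confusion_matrix)

true_labels = [0, 1, 1, 0, 1]        # placeholder data
predicted_labels = [0, 1, 0, 0, 1]   # placeholder data

print(f"accuracy:  {accuracy_score(true_labels, predicted_labels):.3f}")
print(f"precision: {precision_score(true_labels, predicted_labels):.3f}")
print(f"recall:    {recall_score(true_labels, predicted_labels):.3f}")
print(f"f1:        {f1_score(true_labels, predicted_labels):.3f}")

# Confusion matrix as a seaborn heatmap.
cm = confusion_matrix(true_labels, predicted_labels)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion matrix')
plt.show()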