jarif committed on
Commit
37c8ef0
·
verified ·
1 Parent(s): 174a0dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -86
app.py CHANGED
@@ -1,86 +1,103 @@
1
- import gradio as gr
2
- import torch
3
- from pytorchvideo.data.encoded_video import EncodedVideo
4
- from torchvision.transforms import Resize
5
- from pytorchvideo.transforms import UniformTemporalSubsample
6
- from transformers import VideoMAEForVideoClassification
7
-
8
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
-
10
- model_path = "model"
11
- loaded_model = VideoMAEForVideoClassification.from_pretrained(model_path)
12
- loaded_model = loaded_model.to(device)
13
- loaded_model.eval()
14
-
15
- label_names = [
16
- 'Archery', 'BalanceBeam', 'BenchPress', 'ApplyEyeMakeup', 'BasketballDunk',
17
- 'BandMarching', 'BabyCrawling', 'ApplyLipstick', 'BaseballPitch', 'Basketball'
18
- ]
19
-
20
- def load_video(video_path):
21
- try:
22
- video = EncodedVideo.from_path(video_path)
23
- video_data = video.get_clip(start_sec=0, end_sec=video.duration)
24
- return video_data['video']
25
- except Exception as e:
26
- raise ValueError(f"Error loading video: {str(e)}")
27
-
28
- def preprocess_video(video_frames):
29
- try:
30
- transform_temporal = UniformTemporalSubsample(16)
31
- video_frames = transform_temporal(video_frames)
32
- video_frames = video_frames / 255.0
33
-
34
- if video_frames.shape[0] == 3:
35
- video_frames = video_frames.permute(1, 0, 2, 3)
36
-
37
- mean = torch.tensor([0.485, 0.456, 0.406])
38
- std = torch.tensor([0.229, 0.224, 0.225])
39
- for t in range(video_frames.shape[0]):
40
- video_frames[t] = (video_frames[t] - mean[:, None, None]) / std[:, None, None]
41
-
42
- resize_transform = Resize((224, 224))
43
- video_frames = resize_transform(video_frames)
44
- video_frames = video_frames.unsqueeze(0)
45
-
46
- return video_frames
47
- except Exception as e:
48
- raise ValueError(f"Error preprocessing video: {str(e)}")
49
-
50
- def predict_video(video):
51
- try:
52
- video_path = video.name
53
- video_data = load_video(video_path)
54
- processed_video = preprocess_video(video_data)
55
- processed_video = processed_video.to(device)
56
-
57
- with torch.no_grad():
58
- outputs = loaded_model(processed_video)
59
- logits = outputs.logits
60
- probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
61
- top_3 = torch.topk(probabilities, 3)
62
-
63
- results = []
64
- for i in range(3):
65
- idx = top_3.indices[i].item()
66
- prob = top_3.values[i].item()
67
- results.append(f"{label_names[idx]}: {prob*100:.2f}%")
68
-
69
- return "\n".join(results)
70
- except Exception as e:
71
- return f"Error processing video: {str(e)}"
72
-
73
- iface = gr.Interface(
74
- fn=predict_video,
75
- inputs=gr.Video(label="Upload Video"),
76
- outputs=gr.Textbox(label="Top 3 Predictions"),
77
- title="Video Action Recognition",
78
- description="Upload a video to classify the action being performed. The model will return the top 3 predictions with their probabilities.",
79
- examples=[
80
- ["test_video_1.avi"],
81
- ["test_video_2.avi"],
82
- ["test_video_3.avi"]
83
- ]
84
- )
85
-
86
- iface.launch(debug=True, share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from pytorchvideo.data.encoded_video import EncodedVideo
4
+ from pytorchvideo.transforms import UniformTemporalSubsample
5
+ from transformers import VideoMAEForVideoClassification
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms.functional as F_t
8
+ import warnings
9
+ import os
10
+
11
# Silence noisy third-party UserWarnings (e.g. from video/transform libraries).
warnings.filterwarnings('ignore', category=UserWarning)

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fine-tuned VideoMAE checkpoint, expected in the local "model" directory.
model_path = "model"
loaded_model = VideoMAEForVideoClassification.from_pretrained(model_path)
loaded_model = loaded_model.to(device)
loaded_model.eval()  # inference only: disables dropout / training-mode layers

# Class names indexed by the model's output logit positions.
# NOTE(review): this order must match the label mapping used at fine-tuning
# time -- confirm against the training config.
label_names = [
    'Archery', 'BalanceBeam', 'BenchPress', 'ApplyEyeMakeup', 'BasketballDunk',
    'BandMarching', 'BabyCrawling', 'ApplyLipstick', 'BaseballPitch', 'Basketball'
]
24
+
25
def load_video(video_path):
    """Decode the full clip at *video_path* and return its raw frame tensor.

    Raises ValueError (with a descriptive message) on a missing file or any
    decoding failure.
    """
    try:
        # Fail early with a clear message; the except below re-wraps it.
        if not os.path.exists(video_path):
            raise ValueError(f"Video file not found: {video_path}")

        clip = EncodedVideo.from_path(video_path)
        # Grab the entire clip, start to finish.
        return clip.get_clip(start_sec=0, end_sec=clip.duration)['video']
    except Exception as e:
        raise ValueError(f"Error loading video: {str(e)}")
35
+
36
def preprocess_video(video_frames):
    """Turn decoded frames into a normalized (1, T, C, H, W) model input.

    Steps: uniformly subsample 16 frames, scale to [0, 1], move time to the
    leading axis, apply ImageNet normalization, resize to 224x224, and add a
    batch dimension. Raises ValueError on any failure.
    """
    try:
        # Fixed-length temporal sampling: 16 frames spread over the clip.
        video_frames = UniformTemporalSubsample(16)(video_frames)
        video_frames = video_frames.float() / 255.0

        # Decoder output is (C, T, H, W); the model wants time first.
        if video_frames.shape[0] == 3:
            video_frames = video_frames.permute(1, 0, 2, 3)

        # ImageNet channel statistics.
        mean = torch.tensor([0.485, 0.456, 0.406])
        std = torch.tensor([0.229, 0.224, 0.225])
        for frame_idx in range(video_frames.shape[0]):
            video_frames[frame_idx] = F_t.normalize(video_frames[frame_idx], mean, std)

        resized = [F_t.resize(frame, [224, 224], antialias=True) for frame in video_frames]
        return torch.stack(resized).unsqueeze(0)
    except Exception as e:
        raise ValueError(f"Error preprocessing video: {str(e)}")
55
+
56
def predict_video(video):
    """Classify an uploaded video and return the top-3 labels as text.

    *video* is a filepath from the Gradio Video component (or None when
    nothing was uploaded). Returns a human-readable string in all cases;
    errors are reported in the return value rather than raised.
    """
    if video is None:
        return "Please upload a video file."

    try:
        frames = preprocess_video(load_video(video)).to(device)

        with torch.no_grad():
            logits = loaded_model(frames).logits

        probs = F.softmax(logits, dim=-1)[0]
        top3 = torch.topk(probs, 3)

        lines = []
        for label_idx, prob in zip(top3.indices, top3.values):
            lines.append(f"{label_names[label_idx.item()]}: {prob.item():.2%}")
        return "\n".join(lines)
    except Exception as e:
        return f"Error processing video: {str(e)}"
76
+
77
# Gradio UI wiring.
# NOTE(review): in Gradio 4+, gr.Video takes `sources` (a list) instead of
# `source`, and no longer accepts a `type` kwarg (the callback always receives
# a filepath) -- the old kwargs raise TypeError at startup.
iface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(
        label="Upload Video",
        format="mp4",
        sources=["upload"],
    ),
    outputs=gr.Textbox(label="Top 3 Predictions"),
    title="Video Action Recognition",
    description="Upload a video to classify the action being performed. The model will return the top 3 predictions.",
    examples=[
        ["test_video_1.avi"],
        ["test_video_2.avi"],
        ["test_video_3.avi"]
    ],
    # Runs predict_video on every example at startup; requires the example
    # files to be present next to app.py.
    cache_examples=True
)

if __name__ == "__main__":
    # NOTE(review): `ssr=False` was dropped -- launch() has no `ssr` kwarg
    # (Gradio 5 spells it `ssr_mode`) and passing it raises TypeError.
    iface.launch(
        debug=False,
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
    )