Spaces:
Runtime error
Runtime error
Commit
·
69d29a2
1
Parent(s):
3d771a2
Update
Browse files
app.py
CHANGED
@@ -38,19 +38,12 @@ def sample_frame_indices(
|
|
38 |
|
39 |
# @st.cache_resource
|
40 |
@st.experimental_singleton
|
41 |
-
def load_model():
|
42 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
43 |
-
|
44 |
-
)
|
45 |
-
model = TimesformerForVideoClassification.from_pretrained(
|
46 |
-
"facebook/timesformer-base-finetuned-k400"
|
47 |
-
)
|
48 |
return feature_extractor, model
|
49 |
|
50 |
|
51 |
-
feature_extractor, model = load_model()
|
52 |
-
|
53 |
-
|
54 |
def inference(file_path: str):
|
55 |
videoreader = VideoReader(VIDEO_TMP_PATH, num_threads=1, ctx=cpu(0))
|
56 |
|
@@ -96,6 +89,21 @@ with st.expander("INTRODUCTION"):
|
|
96 |
"""
|
97 |
)
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
VIDEO_TMP_PATH = os.path.join("tmp", "tmp.mp4")
|
100 |
uploadedfile = st.file_uploader("Upload file", type=["mp4"])
|
101 |
|
|
|
38 |
|
39 |
# @st.cache_resource
|
40 |
@st.experimental_singleton
|
41 |
+
def load_model(model_name: str):
|
42 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
|
43 |
+
model = TimesformerForVideoClassification.from_pretrained(model_name)
|
|
|
|
|
|
|
|
|
44 |
return feature_extractor, model
|
45 |
|
46 |
|
|
|
|
|
|
|
47 |
def inference(file_path: str):
|
48 |
videoreader = VideoReader(VIDEO_TMP_PATH, num_threads=1, ctx=cpu(0))
|
49 |
|
|
|
89 |
"""
|
90 |
)
|
91 |
|
92 |
+
model_name = st.selectbox(
|
93 |
+
"model_name",
|
94 |
+
(
|
95 |
+
"facebook/timesformer-base-finetuned-k400",
|
96 |
+
"facebook/timesformer-base-finetuned-k600",
|
97 |
+
"facebook/timesformer-base-finetuned-ssv2",
|
98 |
+
"facebook/timesformer-hr-finetuned-k600",
|
99 |
+
"facebook/timesformer-hr-finetuned-k400",
|
100 |
+
"facebook/timesformer-hr-finetuned-ssv2",
|
101 |
+
"fcakyon/timesformer-large-finetuned-k400",
|
102 |
+
"fcakyon/timesformer-large-finetuned-k600",
|
103 |
+
),
|
104 |
+
)
|
105 |
+
feature_extractor, model = load_model(model_name)
|
106 |
+
|
107 |
VIDEO_TMP_PATH = os.path.join("tmp", "tmp.mp4")
|
108 |
uploadedfile = st.file_uploader("Upload file", type=["mp4"])
|
109 |
|