thinh-huynh-re commited on
Commit
69d29a2
·
1 Parent(s): 3d771a2
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -38,19 +38,12 @@ def sample_frame_indices(
38
 
39
  # @st.cache_resource
40
  @st.experimental_singleton
41
- def load_model():
42
- feature_extractor = AutoFeatureExtractor.from_pretrained(
43
- "MCG-NJU/videomae-base-finetuned-kinetics"
44
- )
45
- model = TimesformerForVideoClassification.from_pretrained(
46
- "facebook/timesformer-base-finetuned-k400"
47
- )
48
  return feature_extractor, model
49
 
50
 
51
- feature_extractor, model = load_model()
52
-
53
-
54
  def inference(file_path: str):
55
  videoreader = VideoReader(VIDEO_TMP_PATH, num_threads=1, ctx=cpu(0))
56
 
@@ -96,6 +89,21 @@ with st.expander("INTRODUCTION"):
96
  """
97
  )
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  VIDEO_TMP_PATH = os.path.join("tmp", "tmp.mp4")
100
  uploadedfile = st.file_uploader("Upload file", type=["mp4"])
101
 
 
38
 
39
  # @st.cache_resource
40
  @st.experimental_singleton
41
+ def load_model(model_name: str):
42
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
43
+ model = TimesformerForVideoClassification.from_pretrained(model_name)
 
 
 
 
44
  return feature_extractor, model
45
 
46
 
 
 
 
47
  def inference(file_path: str):
48
  videoreader = VideoReader(VIDEO_TMP_PATH, num_threads=1, ctx=cpu(0))
49
 
 
89
  """
90
  )
91
 
92
+ model_name = st.selectbox(
93
+ "model_name",
94
+ (
95
+ "facebook/timesformer-base-finetuned-k400",
96
+ "facebook/timesformer-base-finetuned-k600",
97
+ "facebook/timesformer-base-finetuned-ssv2",
98
+ "facebook/timesformer-hr-finetuned-k600",
99
+ "facebook/timesformer-hr-finetuned-k400",
100
+ "facebook/timesformer-hr-finetuned-ssv2",
101
+ "fcakyon/timesformer-large-finetuned-k400",
102
+ "fcakyon/timesformer-large-finetuned-k600",
103
+ ),
104
+ )
105
+ feature_extractor, model = load_model(model_name)
106
+
107
  VIDEO_TMP_PATH = os.path.join("tmp", "tmp.mp4")
108
  uploadedfile = st.file_uploader("Upload file", type=["mp4"])
109