HoneyTian commited on
Commit
e6a1b81
·
1 Parent(s): 2090c7e
examples/vm_sound_classification/step_3_train_model.py CHANGED
@@ -162,7 +162,7 @@ def main():
162
  )
163
 
164
  # models
165
- logger.info("prepare models")
166
  config = CnnAudioClassifierConfig.from_pretrained(
167
  pretrained_model_name_or_path=args.config_file,
168
  # num_labels=vocabulary.get_vocab_size(namespace="labels")
 
162
  )
163
 
164
  # models
165
+ logger.info(f"prepare models. config_file: {args.config_file}")
166
  config = CnnAudioClassifierConfig.from_pretrained(
167
  pretrained_model_name_or_path=args.config_file,
168
  # num_labels=vocabulary.get_vocab_size(namespace="labels")
main.py CHANGED
@@ -26,22 +26,27 @@ def get_args():
26
  default=(project_path / "data/examples").as_posix(),
27
  type=str
28
  )
 
 
 
 
 
29
  parser.add_argument(
30
  "--trained_model_dir",
31
  default=(project_path / "trained_models").as_posix(),
32
  type=str
33
  )
 
 
 
 
 
34
  parser.add_argument(
35
  "--server_port",
36
  default=environment.get("server_port", 7860),
37
  type=int
38
  )
39
 
40
- parser.add_argument(
41
- "--models_repo_id",
42
- default="qgyd2021/vm_sound_classification",
43
- type=str
44
- )
45
  args = parser.parse_args()
46
  return args
47
 
@@ -118,7 +123,8 @@ def main():
118
  trained_model_dir.mkdir(parents=True, exist_ok=True)
119
  _ = snapshot_download(
120
  repo_id=args.models_repo_id,
121
- local_dir=trained_model_dir.as_posix()
 
122
  )
123
 
124
  # examples
 
26
  default=(project_path / "data/examples").as_posix(),
27
  type=str
28
  )
29
+ parser.add_argument(
30
+ "--models_repo_id",
31
+ default="qgyd2021/vm_sound_classification",
32
+ type=str
33
+ )
34
  parser.add_argument(
35
  "--trained_model_dir",
36
  default=(project_path / "trained_models").as_posix(),
37
  type=str
38
  )
39
+ parser.add_argument(
40
+ "--hf_token",
41
+ default=environment.get("hf_token"),
42
+ type=str,
43
+ )
44
  parser.add_argument(
45
  "--server_port",
46
  default=environment.get("server_port", 7860),
47
  type=int
48
  )
49
 
 
 
 
 
 
50
  args = parser.parse_args()
51
  return args
52
 
 
123
  trained_model_dir.mkdir(parents=True, exist_ok=True)
124
  _ = snapshot_download(
125
  repo_id=args.models_repo_id,
126
+ local_dir=trained_model_dir.as_posix(),
127
+ token=args.hf_token,
128
  )
129
 
130
  # examples