chong.zhang commited on
Commit
4e5c199
·
1 Parent(s): 06de4df
Files changed (2) hide show
  1. app.py +1 -1
  2. inspiremusic/cli/inference.py +5 -4
app.py CHANGED
@@ -43,7 +43,7 @@ def get_args(
43
  "output_sample_rate" : output_sample_rate,
44
  "min_generate_audio_seconds": 10.0,
45
  "max_generate_audio_seconds": max_generate_audio_seconds,
46
- "model_dir" : os.path.join("iic",
47
  model_name),
48
  "result_dir" : "exp/inspiremusic",
49
  "output_fn" : generate_filename(),
 
43
  "output_sample_rate" : output_sample_rate,
44
  "min_generate_audio_seconds": 10.0,
45
  "max_generate_audio_seconds": max_generate_audio_seconds,
46
+ "model_dir" : os.path.join("pretrained_models",
47
  model_name),
48
  "result_dir" : "exp/inspiremusic",
49
  "output_fn" : generate_filename(),
inspiremusic/cli/inference.py CHANGED
@@ -53,19 +53,20 @@ class InspireMusicUnified:
53
  os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
54
 
55
  # Set model_dir or default to downloading if it doesn't exist
 
56
  if model_dir is None:
57
- model_dir = f"../../pretrained_models/{model_name}"
58
  print(model_dir)
59
  if not os.path.isfile(f"{model_dir}/llm.pt"):
60
  if hub == "modelscope":
61
  from modelscope import snapshot_download
62
  if model_name == "InspireMusic-Base":
63
- model_dir_tmp = snapshot_download(f"iic/InspireMusic", cache_dir=model_dir)
64
  else:
65
- model_dir_tmp = snapshot_download(f"iic/{model_name}", cache_dir=model_dir)
66
  elif hub == "huggingface":
67
  from huggingface_hub import snapshot_download
68
- model_dir_tmp = snapshot_download(repo_id=f"FunAudioLLM/{model_name}", cache_dir=model_dir)
69
  print(model_dir_tmp, model_dir)
70
  shutil.move(model_dir_tmp, model_dir)
71
  # shutil.rmtree(model_dir_tmp)
 
53
  os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
54
 
55
  # Set model_dir or default to downloading if it doesn't exist
56
+ download_model_dir = "../../pretrained_models"
57
  if model_dir is None:
58
+ model_dir = f"{download_model_dir}/{model_name}"
59
  print(model_dir)
60
  if not os.path.isfile(f"{model_dir}/llm.pt"):
61
  if hub == "modelscope":
62
  from modelscope import snapshot_download
63
  if model_name == "InspireMusic-Base":
64
+ model_dir_tmp = snapshot_download(f"iic/InspireMusic", cache_dir=download_model_dir)
65
  else:
66
+ model_dir_tmp = snapshot_download(f"iic/{model_name}", cache_dir=download_model_dir)
67
  elif hub == "huggingface":
68
  from huggingface_hub import snapshot_download
69
+ model_dir_tmp = snapshot_download(repo_id=f"FunAudioLLM/{model_name}", cache_dir=download_model_dir)
70
  print(model_dir_tmp, model_dir)
71
  shutil.move(model_dir_tmp, model_dir)
72
  # shutil.rmtree(model_dir_tmp)