chong.zhang commited on
Commit
5a9eb30
·
1 Parent(s): db3e4d9
app.py CHANGED
@@ -55,9 +55,6 @@ def get_args(
55
 
56
  if args["time_start"] is None:
57
  args["time_start"] = 0.0
58
- # if args["time_end"] is None:
59
- # args["time_end"] = args["time_start"] + args["max_generate_audio_seconds"]
60
- # if args["time_start"] > args["time_end"]:
61
  args["time_end"] = args["time_start"] + args["max_generate_audio_seconds"]
62
 
63
  print(args)
 
55
 
56
  if args["time_start"] is None:
57
  args["time_start"] = 0.0
 
 
 
58
  args["time_end"] = args["time_start"] + args["max_generate_audio_seconds"]
59
 
60
  print(args)
inspiremusic/cli/inference.py CHANGED
@@ -18,13 +18,11 @@ import torchaudio
18
  import time
19
  import logging
20
  import argparse
21
-
22
- from modelscope import snapshot_download
23
  from inspiremusic.cli.inspiremusic import InspireMusic
24
  from inspiremusic.utils.file_utils import logging
25
  import torch
26
  from inspiremusic.utils.audio_utils import trim_audio, fade_out
27
- from transformers import AutoModel
28
 
29
  def set_env_variables():
30
  os.environ['PYTHONIOENCODING'] = 'UTF-8'
@@ -50,13 +48,27 @@ class InspireMusicUnified:
50
  fast: bool = False,
51
  fp16: bool = True,
52
  gpu: int = 0,
53
- result_dir: str = None):
 
54
  os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
55
 
56
  # Set model_dir or default to downloading if it doesn't exist
57
- self.model_dir = model_dir or f"iic/{model_name}"
58
- if not os.path.exists(self.model_dir):
59
- self.model_dir = snapshot_download(f"iic/{model_name}", cache_dir=self.model_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  self.sample_rate = sample_rate
62
  self.output_sample_rate = 24000 if fast else output_sample_rate
 
18
  import time
19
  import logging
20
  import argparse
21
+ import shutil
 
22
  from inspiremusic.cli.inspiremusic import InspireMusic
23
  from inspiremusic.utils.file_utils import logging
24
  import torch
25
  from inspiremusic.utils.audio_utils import trim_audio, fade_out
 
26
 
27
  def set_env_variables():
28
  os.environ['PYTHONIOENCODING'] = 'UTF-8'
 
48
  fast: bool = False,
49
  fp16: bool = True,
50
  gpu: int = 0,
51
+ result_dir: str = None,
52
+ hub: str = "modelscope"):
53
  os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
54
 
55
  # Set model_dir or default to downloading if it doesn't exist
56
+ if model_dir is None:
57
+ model_dir = f"../../pretrained_models/{model_name}"
58
+
59
+ if not os.path.isfile(f"{model_dir}/llm.pt"):
60
+ if hub == "modelscope":
61
+ from modelscope import snapshot_download
62
+ if model_name == "InspireMusic-Base":
63
+ model_dir_tmp = snapshot_download(f"iic/InspireMusic", cache_dir=model_dir)
64
+ else:
65
+ model_dir_tmp = snapshot_download(f"iic/{model_name}", cache_dir=model_dir)
66
+ elif hub == "huggingface":
67
+ from huggingface_hub import snapshot_download
68
+ model_dir_tmp = snapshot_download(repo_id=f"FunAudioLLM/{model_name}", cache_dir=model_dir)
69
+ shutil.move(model_dir_tmp, model_dir)
70
+ print(model_dir_tmp, model_dir)
71
+ self.model_dir = model_dir
72
 
73
  self.sample_rate = sample_rate
74
  self.output_sample_rate = 24000 if fast else output_sample_rate
inspiremusic/cli/inspiremusic.py CHANGED
@@ -15,18 +15,37 @@ import os
15
  import time
16
  from tqdm import tqdm
17
  from hyperpyyaml import load_hyperpyyaml
18
- from modelscope import snapshot_download
19
  from inspiremusic.cli.frontend import InspireMusicFrontEnd
20
  from inspiremusic.cli.model import InspireMusicModel
21
  from inspiremusic.utils.file_utils import logging
22
  import torch
 
23
 
24
  class InspireMusic:
25
- def __init__(self, model_dir, load_jit=True, load_onnx=False, fast = False, fp16=True):
26
  instruct = True if '-Instruct' in model_dir else False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  self.model_dir = model_dir
28
- if not os.path.exists(model_dir):
29
- model_dir = snapshot_download(model_dir)
 
30
  with open('{}/inspiremusic.yaml'.format(model_dir), 'r') as f:
31
  configs = load_hyperpyyaml(f)
32
 
 
15
  import time
16
  from tqdm import tqdm
17
  from hyperpyyaml import load_hyperpyyaml
 
18
  from inspiremusic.cli.frontend import InspireMusicFrontEnd
19
  from inspiremusic.cli.model import InspireMusicModel
20
  from inspiremusic.utils.file_utils import logging
21
  import torch
22
+ import shutil
23
 
24
  class InspireMusic:
25
+ def __init__(self, model_dir, load_jit=True, load_onnx=False, fast = False, fp16=True, hub="modelscope"):
26
  instruct = True if '-Instruct' in model_dir else False
27
+
28
+ if model_dir is None:
29
+ model_dir = f"../../pretrained_models/InspireMusic-1.5B-Long"
30
+
31
+ if not os.path.isfile(f"{model_dir}/llm.pt"):
32
+ model_name = model_dir.split("/")[-1]
33
+ if hub == "modelscope":
34
+ from modelscope import snapshot_download
35
+ if model_name == "InspireMusic-Base":
36
+ model_dir_tmp = snapshot_download(f"iic/InspireMusic", cache_dir=model_dir)
37
+ else:
38
+ model_dir_tmp = snapshot_download(f"iic/{model_name}", cache_dir=model_dir)
39
+ elif hub == "huggingface":
40
+ from huggingface_hub import snapshot_download
41
+ model_dir_tmp = snapshot_download(repo_id=f"FunAudioLLM/{model_name}", cache_dir=model_dir)
42
+ shutil.move(model_dir_tmp, model_dir)
43
+ print(model_dir_tmp, model_dir)
44
+
45
  self.model_dir = model_dir
46
+
47
+ assert os.path.exists(f'{model_dir}/inspiremusic.yaml')
48
+
49
  with open('{}/inspiremusic.yaml'.format(model_dir), 'r') as f:
50
  configs = load_hyperpyyaml(f)
51