Spaces:
Runtime error
Runtime error
adding support for cpu
Browse files
audiocraft/models/loaders.py
CHANGED
|
@@ -80,8 +80,6 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_di
|
|
| 80 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
| 81 |
cfg.device = str(device)
|
| 82 |
if cfg.device == 'cpu':
|
| 83 |
-
cfg.transformer_lm.memory_efficient = False
|
| 84 |
-
cfg.transformer_lm.custom = True
|
| 85 |
cfg.dtype = 'float32'
|
| 86 |
else:
|
| 87 |
cfg.dtype = 'float16'
|
|
|
|
| 80 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
| 81 |
cfg.device = str(device)
|
| 82 |
if cfg.device == 'cpu':
|
|
|
|
|
|
|
| 83 |
cfg.dtype = 'float32'
|
| 84 |
else:
|
| 85 |
cfg.dtype = 'float16'
|
audiocraft/models/musicgen.py
CHANGED
|
@@ -68,7 +68,7 @@ class MusicGen:
|
|
| 68 |
return self.compression_model.channels
|
| 69 |
|
| 70 |
@staticmethod
|
| 71 |
-
def get_pretrained(name: str = 'melody', device=
|
| 72 |
"""Return pretrained model, we provide four models:
|
| 73 |
- small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
|
| 74 |
- medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
|
|
@@ -76,11 +76,17 @@ class MusicGen:
|
|
| 76 |
- large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
|
| 77 |
"""
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
if name == 'debug':
|
| 80 |
# used only for unit tests
|
| 81 |
compression_model = get_debug_compression_model(device)
|
| 82 |
lm = get_debug_lm_model(device)
|
| 83 |
-
return MusicGen(name, compression_model, lm
|
| 84 |
|
| 85 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
| 86 |
raise ValueError(
|
|
@@ -313,7 +319,6 @@ class MusicGen:
|
|
| 313 |
all_tokens.append(prompt_tokens)
|
| 314 |
prompt_length = prompt_tokens.shape[-1]
|
| 315 |
|
| 316 |
-
|
| 317 |
stride_tokens = int(self.frame_rate * self.extend_stride)
|
| 318 |
|
| 319 |
while current_gen_offset + prompt_length < total_gen_len:
|
|
|
|
| 68 |
return self.compression_model.channels
|
| 69 |
|
| 70 |
@staticmethod
|
| 71 |
+
def get_pretrained(name: str = 'melody', device=None):
|
| 72 |
"""Return pretrained model, we provide four models:
|
| 73 |
- small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
|
| 74 |
- medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
|
|
|
|
| 76 |
- large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
|
| 77 |
"""
|
| 78 |
|
| 79 |
+
if device is None:
|
| 80 |
+
if torch.cuda.device_count():
|
| 81 |
+
device = 'cuda'
|
| 82 |
+
else:
|
| 83 |
+
device = 'cpu'
|
| 84 |
+
|
| 85 |
if name == 'debug':
|
| 86 |
# used only for unit tests
|
| 87 |
compression_model = get_debug_compression_model(device)
|
| 88 |
lm = get_debug_lm_model(device)
|
| 89 |
+
return MusicGen(name, compression_model, lm)
|
| 90 |
|
| 91 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
| 92 |
raise ValueError(
|
|
|
|
| 319 |
all_tokens.append(prompt_tokens)
|
| 320 |
prompt_length = prompt_tokens.shape[-1]
|
| 321 |
|
|
|
|
| 322 |
stride_tokens = int(self.frame_rate * self.extend_stride)
|
| 323 |
|
| 324 |
while current_gen_offset + prompt_length < total_gen_len:
|
tests/models/test_musicgen.py
CHANGED
|
@@ -51,6 +51,7 @@ class TestSEANetModel:
|
|
| 51 |
|
| 52 |
def test_generate_long(self):
|
| 53 |
mg = self.get_musicgen()
|
|
|
|
| 54 |
mg.set_generation_params(duration=4., stride_extend=2.)
|
| 55 |
wav = mg.generate(
|
| 56 |
['youpi', 'lapin dort'])
|
|
|
|
| 51 |
|
| 52 |
def test_generate_long(self):
|
| 53 |
mg = self.get_musicgen()
|
| 54 |
+
mg.max_duration = 3.
|
| 55 |
mg.set_generation_params(duration=4., stride_extend=2.)
|
| 56 |
wav = mg.generate(
|
| 57 |
['youpi', 'lapin dort'])
|