Spaces:
Runtime error
Runtime error
nit
Browse files
audiocraft/models/musicgen.py
CHANGED
|
@@ -115,7 +115,7 @@ class MusicGen:
|
|
| 115 |
should we extend the audio each time. Larger values will mean less context is
|
| 116 |
preserved, and shorter value will require extra computations.
|
| 117 |
"""
|
| 118 |
-
assert extend_stride
|
| 119 |
self.extend_stride = extend_stride
|
| 120 |
self.duration = duration
|
| 121 |
self.generation_params = {
|
|
|
|
| 115 |
should we extend the audio each time. Larger values will mean less context is
|
| 116 |
preserved, and shorter value will require extra computations.
|
| 117 |
"""
|
| 118 |
+
assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
|
| 119 |
self.extend_stride = extend_stride
|
| 120 |
self.duration = duration
|
| 121 |
self.generation_params = {
|
tests/models/test_musicgen.py
CHANGED
|
@@ -13,7 +13,7 @@ from audiocraft.models import MusicGen
|
|
| 13 |
class TestSEANetModel:
|
| 14 |
def get_musicgen(self):
|
| 15 |
mg = MusicGen.get_pretrained(name='debug', device='cpu')
|
| 16 |
-
mg.set_generation_params(duration=2.0)
|
| 17 |
return mg
|
| 18 |
|
| 19 |
def test_base(self):
|
|
@@ -51,7 +51,7 @@ class TestSEANetModel:
|
|
| 51 |
|
| 52 |
def test_generate_long(self):
|
| 53 |
mg = self.get_musicgen()
|
| 54 |
-
mg.set_generation_params(duration=4.)
|
| 55 |
wav = mg.generate(
|
| 56 |
['youpi', 'lapin dort'])
|
| 57 |
assert list(wav.shape) == [2, 1, 32000 * 4]
|
|
|
|
| 13 |
class TestSEANetModel:
|
| 14 |
def get_musicgen(self):
|
| 15 |
mg = MusicGen.get_pretrained(name='debug', device='cpu')
|
| 16 |
+
mg.set_generation_params(duration=2.0, stride_extend=2.)
|
| 17 |
return mg
|
| 18 |
|
| 19 |
def test_base(self):
|
|
|
|
| 51 |
|
| 52 |
def test_generate_long(self):
|
| 53 |
mg = self.get_musicgen()
|
| 54 |
+
mg.set_generation_params(duration=4., stride_extend=2.)
|
| 55 |
wav = mg.generate(
|
| 56 |
['youpi', 'lapin dort'])
|
| 57 |
assert list(wav.shape) == [2, 1, 32000 * 4]
|