# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest import skip

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow
from transformers.trainer_utils import set_seed


if is_torch_available():
    import torch

    from transformers import JukeboxModel, JukeboxPrior, JukeboxTokenizer

@require_torch
class Jukebox1bModelTester(unittest.TestCase):
    all_model_classes = (JukeboxModel,) if is_torch_available() else ()
    model_id = "openai/jukebox-1b-lyrics"
    metas = {
        "artist": "Zac Brown Band",
        "genres": "Country",
        "lyrics": """I met a traveller from an antique land,
Who said "Two vast and trunkless legs of stone
Stand in the desert. . . . Near them, on the sand,
Half sunk a shattered visage lies, whose frown,
And wrinkled lip, and sneer of cold command,
Tell that its sculptor well those passions read
Which yet survive, stamped on these lifeless things,
The hand that mocked them, and the heart that fed;
And on the pedestal, these words appear:
My name is Ozymandias, King of Kings;
Look on my Works, ye Mighty, and despair!
Nothing beside remains. Round the decay
Of that colossal Wreck, boundless and bare
The lone and level sands stretch far away
""",
    }
    # fmt: off
    EXPECTED_OUTPUT_2 = [
        1864, 1536, 1213, 1870, 1357, 1536, 519, 880, 1323, 789, 1082, 534,
        1000, 1445, 1105, 1130, 967, 515, 1434, 1620, 534, 1495, 283, 1445,
        333, 1307, 539, 1631, 1528, 375, 1434, 673, 627, 710, 778, 1883,
        1405, 1276, 1455, 1228
    ]

    EXPECTED_OUTPUT_2_PT_2 = [
        1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653
    ]

    EXPECTED_OUTPUT_1 = [
        1125, 1751, 697, 1776, 1141, 1476, 391, 697, 1125, 684, 867, 416,
        844, 1372, 1274, 717, 1274, 844, 1299, 1419, 697, 1370, 317, 1125,
        191, 1440, 1370, 1440, 1370, 282, 1621, 1370, 368, 349, 867, 1872,
        1262, 869, 1728, 747
    ]

    EXPECTED_OUTPUT_1_PT_2 = [
        416, 416, 1125, 1125, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416
    ]

    EXPECTED_OUTPUT_0 = [
        1755, 842, 307, 1843, 1022, 1395, 234, 1554, 806, 739, 1022, 442,
        616, 556, 268, 1499, 933, 457, 1440, 1837, 755, 985, 308, 902,
        293, 1443, 1671, 1141, 1533, 555, 1562, 1061, 287, 417, 1022, 2008,
        1186, 1015, 1777, 268
    ]

    EXPECTED_OUTPUT_0_PT_2 = [
        854, 842, 1353, 114, 1353, 842, 185, 842, 185, 114, 591, 842,
        185, 417, 185, 842, 307, 842, 591, 842, 185, 842, 307, 842,
        591, 842, 1353, 842, 185, 842, 591, 842, 591, 114, 591, 842,
        185, 842, 591, 89
    ]

    EXPECTED_Y_COND = [1058304, 0, 786432, 7169, 507, 76, 27, 40, 30, 76]

    EXPECTED_PRIMED_0 = [
        390, 1160, 1002, 1907, 1788, 1788, 1788, 1907, 1002, 1002, 1854, 1002,
        1002, 1002, 1002, 1002, 1002, 1160, 1160, 1606, 596, 596, 1160, 1002,
        1516, 596, 1002, 1002, 1002, 1907, 1788, 1788, 1788, 1854, 1788, 1907,
        1907, 1788, 596, 1626
    ]

    EXPECTED_PRIMED_1 = [
        1236, 1668, 1484, 1920, 1848, 1409, 139, 864, 1828, 1272, 1599, 824,
        1672, 139, 555, 1484, 824, 1920, 555, 596, 1579, 1599, 1231, 1599,
        1637, 1407, 212, 824, 1599, 116, 1433, 824, 258, 1599, 1433, 1895,
        1063, 1433, 1433, 1599
    ]

    EXPECTED_PRIMED_2 = [
        1684, 1873, 1119, 1189, 395, 611, 1901, 972, 890, 1337, 1392, 1927,
        96, 972, 672, 780, 1119, 890, 158, 771, 1073, 1927, 353, 1331,
        1269, 1459, 1333, 1645, 812, 1577, 1337, 606, 353, 981, 1466, 619,
        197, 391, 302, 1930
    ]

    EXPECTED_VQVAE_ENCODE = [
        390, 1160, 1002, 1907, 1788, 1788, 1788, 1907, 1002, 1002, 1854, 1002,
        1002, 1002, 1002, 1002, 1002, 1160, 1160, 1606, 596, 596, 1160, 1002,
        1516, 596, 1002, 1002, 1002, 1907, 1788, 1788, 1788, 1854, 1788, 1907,
        1907, 1788, 596, 1626
    ]

    EXPECTED_VQVAE_DECODE = [
        -0.0492, -0.0524, -0.0565, -0.0640, -0.0686, -0.0684, -0.0677, -0.0664,
        -0.0605, -0.0490, -0.0330, -0.0168, -0.0083, -0.0075, -0.0051, 0.0025,
        0.0136, 0.0261, 0.0386, 0.0497, 0.0580, 0.0599, 0.0583, 0.0614,
        0.0740, 0.0889, 0.1023, 0.1162, 0.1211, 0.1212, 0.1251, 0.1336,
        0.1502, 0.1686, 0.1883, 0.2148, 0.2363, 0.2458, 0.2507, 0.2531
    ]

    EXPECTED_AUDIO_COND = [
        0.0256, -0.0544, 0.1600, -0.0032, 0.1066, 0.0825, -0.0013, 0.3440,
        0.0210, 0.0412, -0.1777, -0.0892, -0.0164, 0.0285, -0.0613, -0.0617,
        -0.0137, -0.0201, -0.0175, 0.0215, -0.0627, 0.0520, -0.0730, 0.0970,
        -0.0100, 0.0442, -0.0586, 0.0207, -0.0015, -0.0082
    ]

    EXPECTED_META_COND = [
        0.0415, 0.0877, 0.0022, -0.0055, 0.0751, 0.0334, 0.0324, -0.0068,
        0.0011, 0.0017, -0.0676, 0.0655, -0.0143, 0.0399, 0.0303, 0.0743,
        -0.0168, -0.0394, -0.1113, 0.0124, 0.0442, 0.0267, -0.0003, -0.1536,
        -0.0116, -0.1837, -0.0180, -0.1026, -0.0777, -0.0456
    ]

    EXPECTED_LYRIC_COND = [
        76, 27, 40, 30, 76, 46, 44, 47, 40, 37, 38, 31, 45, 45, 76, 38, 31, 33,
        45, 76, 41, 32, 76, 45, 46, 41, 40, 31, 78, 76
    ]
    # fmt: on
    def prepare_inputs(self):
        tokenizer = JukeboxTokenizer.from_pretrained(self.model_id)
        tokens = tokenizer(**self.metas)["input_ids"]
        return tokens
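    # test_sampling runs each of the three priors on its own from empty token sequences.
    # Note the index mapping visible in the assertions: model.priors[0] (the top, coarsest
    # level) is checked against EXPECTED_OUTPUT_2 and priors[2] against EXPECTED_OUTPUT_0,
    # so the constant suffixes appear to follow the original OpenAI level numbering. The
    # *_PT_2 variants are accepted as alternative references, presumably produced by a
    # different PyTorch version.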
    @slow
    def test_sampling(self):
        model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
        labels = self.prepare_inputs()

        set_seed(0)
        zs = [torch.zeros(1, 0, dtype=torch.long).cpu() for _ in range(3)]
        zs = model._sample(zs, labels, [0], sample_length=40 * model.priors[0].raw_to_tokens, save_results=False)
        self.assertIn(zs[0][0].detach().cpu().tolist(), [self.EXPECTED_OUTPUT_2, self.EXPECTED_OUTPUT_2_PT_2])

        set_seed(0)
        zs = model._sample(zs, labels, [1], sample_length=40 * model.priors[1].raw_to_tokens, save_results=False)
        self.assertIn(zs[1][0].detach().cpu().tolist(), [self.EXPECTED_OUTPUT_1, self.EXPECTED_OUTPUT_1_PT_2])

        set_seed(0)
        zs = model._sample(zs, labels, [2], sample_length=40 * model.priors[2].raw_to_tokens, save_results=False)
        self.assertIn(zs[2][0].detach().cpu().tolist(), [self.EXPECTED_OUTPUT_0, self.EXPECTED_OUTPUT_0_PT_2])
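    # test_conditioning checks the inputs the top-level prior builds before sampling: with
    # empty music tokens there is no upper-level audio conditioning, the first entries of
    # the metadata row (which appear to be total length, offset, sample length, artist/genre
    # ids and the first lyric tokens) must equal EXPECTED_Y_COND, and the audio, metadata
    # and lyric conditionings from get_cond must match the EXPECTED_*_COND references.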
    @slow
    def test_conditioning(self):
        torch.backends.cuda.matmul.allow_tf32 = False
        model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()

        labels = self.prepare_inputs()
        set_seed(0)
        zs = [torch.zeros(1, 0, dtype=torch.long) for _ in range(3)]

        top_prior = model.priors[0]
        start = 0
        music_token_conds = top_prior.get_music_tokens_conds(zs, start=start, end=start + top_prior.n_ctx)
        metadata = top_prior.get_metadata(labels[0].clone(), start, 1058304, 0)

        self.assertIsNone(music_token_conds)
        self.assertListEqual(metadata.numpy()[0][:10].tolist(), self.EXPECTED_Y_COND)

        audio_conditioning, metadata_conditioning, lyric_tokens = top_prior.get_cond(music_token_conds, metadata)
        torch.testing.assert_allclose(
            audio_conditioning[0][0][:30].detach(), torch.tensor(self.EXPECTED_AUDIO_COND), atol=1e-4, rtol=1e-4
        )
        torch.testing.assert_allclose(
            metadata_conditioning[0][0][:30].detach(), torch.tensor(self.EXPECTED_META_COND), atol=1e-4, rtol=1e-4
        )
        torch.testing.assert_allclose(
            lyric_tokens[0, :30].detach(), torch.tensor(self.EXPECTED_LYRIC_COND), atol=1e-4, rtol=1e-4
        )
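    # test_primed_sampling continues sampling from an existing audio prompt instead of
    # starting from empty sequences: a random waveform is encoded with the VQ-VAE at the
    # matching level, and the sampled tokens are zero-padded to the 2048-token context
    # before being handed to the next level down as upper-level conditioning.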
    @slow
    def test_primed_sampling(self):
        torch.backends.cuda.matmul.allow_tf32 = False

        model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
        set_seed(0)
        waveform = torch.rand((1, 5120, 1))
        tokens = list(self.prepare_inputs())

        zs = [model.vqvae.encode(waveform, start_level=2, bs_chunks=waveform.shape[0])[0], None, None]
        zs = model._sample(
            zs, tokens, sample_levels=[0], save_results=False, sample_length=40 * model.priors[0].raw_to_tokens
        )
        torch.testing.assert_allclose(zs[0][0][:40], torch.tensor(self.EXPECTED_PRIMED_0))

        upper_2 = torch.cat((zs[0], torch.zeros(1, 2048 - zs[0].shape[-1])), dim=-1).long()
        zs = [upper_2, model.vqvae.encode(waveform, start_level=1, bs_chunks=waveform.shape[0])[0], None]
        zs = model._sample(
            zs, tokens, sample_levels=[1], save_results=False, sample_length=40 * model.priors[1].raw_to_tokens
        )
        torch.testing.assert_allclose(zs[1][0][:40], torch.tensor(self.EXPECTED_PRIMED_1))

        upper_1 = torch.cat((zs[1], torch.zeros(1, 2048 - zs[1].shape[-1])), dim=-1).long()
        zs = [upper_2, upper_1, model.vqvae.encode(waveform, start_level=0, bs_chunks=waveform.shape[0])[0]]
        zs = model._sample(
            zs, tokens, sample_levels=[2], save_results=False, sample_length=40 * model.priors[2].raw_to_tokens
        )
        torch.testing.assert_allclose(zs[2][0][:40].cpu(), torch.tensor(self.EXPECTED_PRIMED_2))
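    # test_vqvae checks the VQ-VAE round trip on random audio: the top-level codebook
    # indices returned by encode() and the first reconstructed samples from decode()
    # must match the EXPECTED_VQVAE_* references.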
    @slow
    def test_vqvae(self):
        model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
        set_seed(0)
        x = torch.rand((1, 5120, 1))
        with torch.no_grad():
            zs = model.vqvae.encode(x, start_level=2, bs_chunks=x.shape[0])
        torch.testing.assert_allclose(zs[0][0], torch.tensor(self.EXPECTED_VQVAE_ENCODE))

        with torch.no_grad():
            x = model.vqvae.decode(zs, start_level=2, bs_chunks=x.shape[0])
        torch.testing.assert_allclose(x[0, :40, 0], torch.tensor(self.EXPECTED_VQVAE_DECODE), atol=1e-4, rtol=1e-4)

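# The 5b tester mirrors the 1b one: same lyrics and metadata, but it targets the larger
# openai/jukebox-5b-lyrics checkpoint, samples 60 tokens per level instead of 40, and adds
# GPU and fp16 variants of the sampling test.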
@require_torch
class Jukebox5bModelTester(unittest.TestCase):
    all_model_classes = (JukeboxModel,) if is_torch_available() else ()
    model_id = "openai/jukebox-5b-lyrics"
    metas = {
        "artist": "Zac Brown Band",
        "genres": "Country",
        "lyrics": """I met a traveller from an antique land,
Who said "Two vast and trunkless legs of stone
Stand in the desert. . . . Near them, on the sand,
Half sunk a shattered visage lies, whose frown,
And wrinkled lip, and sneer of cold command,
Tell that its sculptor well those passions read
Which yet survive, stamped on these lifeless things,
The hand that mocked them, and the heart that fed;
And on the pedestal, these words appear:
My name is Ozymandias, King of Kings;
Look on my Works, ye Mighty, and despair!
Nothing beside remains. Round the decay
Of that colossal Wreck, boundless and bare
The lone and level sands stretch far away
""",
    }
    # fmt: off
    EXPECTED_OUTPUT_2 = [
        1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        1489, 1489, 1489, 1489, 1150, 1853, 1509, 1150, 1357, 1509, 6, 1272
    ]

    EXPECTED_OUTPUT_2_PT_2 = [
        1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653
    ]

    EXPECTED_OUTPUT_1 = [
        1125, 416, 1125, 1125, 1125, 1125, 1125, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416
    ]

    EXPECTED_OUTPUT_1_PT_2 = [
        416, 416, 1125, 1125, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416
    ]

    EXPECTED_OUTPUT_0 = [
        1755, 1061, 234, 1755, 1061, 1755, 185, 290, 307, 307, 616, 616,
        616, 616, 616, 616, 307, 290, 417, 1755, 234, 1755, 185, 290,
        290, 290, 307, 616, 616, 616, 616, 616, 290, 234, 234, 1755,
        234, 234, 1755, 234, 185, 185, 307, 616, 616, 616, 616, 290,
        1755, 1755, 1755, 234, 234, 1755, 1572, 290, 307, 616, 34, 616
    ]

    EXPECTED_OUTPUT_0_PT_2 = [
        854, 842, 1353, 114, 1353, 842, 185, 842, 185, 114, 591, 842, 185,
        417, 185, 842, 307, 842, 591, 842, 185, 842, 185, 842, 591, 842,
        1353, 842, 185, 842, 591, 842, 591, 114, 591, 842, 185, 842, 591,
        89, 591, 842, 591, 842, 591, 417, 1372, 842, 1372, 842, 34, 842,
        185, 89, 591, 842, 185, 842, 591, 632
    ]

    EXPECTED_GPU_OUTPUTS_2 = [
        1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653
    ]

    EXPECTED_GPU_OUTPUTS_2_PT_2 = [
        1489, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653, 653,
        653, 653, 653, 653, 653, 653, 653, 1853, 1177, 1536, 1228,
        710, 475, 1489, 1229, 1224, 231, 1224, 252, 1434, 653, 475,
        1106, 1877, 1599, 1228, 1600, 1683, 1182, 1853, 475, 1864,
        252, 1229, 1434, 2001
    ]

    EXPECTED_GPU_OUTPUTS_1 = [
        1125, 1125, 416, 1125, 1125, 416, 1125, 1125, 416, 416, 1125, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416,
        416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416, 416
    ]

    EXPECTED_GPU_OUTPUTS_0 = [
        491, 1755, 34, 1613, 1755, 417, 992, 1613, 222, 842, 1353, 1613,
        844, 632, 185, 1613, 844, 632, 185, 1613, 185, 842, 677, 1613,
        185, 114, 1353, 1613, 307, 89, 844, 1613, 307, 1332, 234, 1979,
        307, 89, 1353, 616, 34, 842, 185, 842, 34, 842, 185, 842,
        307, 114, 185, 89, 34, 1268, 185, 89, 34, 842, 185, 89
    ]
    # fmt: on
    def prepare_inputs(self, model_id):
        tokenizer = JukeboxTokenizer.from_pretrained(model_id)
        tokens = tokenizer(**self.metas)["input_ids"]
        return tokens
    @slow
    def test_sampling(self):
        model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
        labels = self.prepare_inputs(self.model_id)

        set_seed(0)
        zs = [torch.zeros(1, 0, dtype=torch.long).cpu() for _ in range(3)]
        zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False)
        self.assertIn(zs[0][0].detach().cpu().tolist(), [self.EXPECTED_OUTPUT_2, self.EXPECTED_OUTPUT_2_PT_2])

        set_seed(0)
        zs = model._sample(zs, labels, [1], sample_length=60 * model.priors[1].raw_to_tokens, save_results=False)
        self.assertIn(zs[1][0].detach().cpu().tolist(), [self.EXPECTED_OUTPUT_1, self.EXPECTED_OUTPUT_1_PT_2])

        set_seed(0)
        zs = model._sample(zs, labels, [2], sample_length=60 * model.priors[2].raw_to_tokens, save_results=False)
        self.assertIn(zs[2][0].detach().cpu().tolist(), [self.EXPECTED_OUTPUT_0, self.EXPECTED_OUTPUT_0_PT_2])
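    # test_slow_sampling repeats the sampling check on GPU, moving one prior at a time to
    # CUDA and back to CPU so the three 5b priors never have to fit in GPU memory at once.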
    @slow
    def test_slow_sampling(self):
        model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
        labels = [i.cuda() for i in self.prepare_inputs(self.model_id)]

        set_seed(0)
        model.priors[0].cuda()
        zs = [torch.zeros(1, 0, dtype=torch.long).cuda() for _ in range(3)]
        zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False)
        torch.testing.assert_allclose(zs[0][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_2))
        model.priors[0].cpu()

        set_seed(0)
        model.priors[1].cuda()
        zs = model._sample(zs, labels, [1], sample_length=60 * model.priors[1].raw_to_tokens, save_results=False)
        torch.testing.assert_allclose(zs[1][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_1))
        model.priors[1].cpu()

        set_seed(0)
        model.priors[2].cuda()
        zs = model._sample(zs, labels, [2], sample_length=60 * model.priors[2].raw_to_tokens, save_results=False)
        torch.testing.assert_allclose(zs[2][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_0))
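    # test_fp16_slow_sampling loads a single top-level prior checkpoint in half precision
    # and samples 60 tokens directly through JukeboxPrior.sample, accepting either the fp32
    # GPU reference or the alternative (presumably torch-2.x) reference above.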
    @slow
    def test_fp16_slow_sampling(self):
        prior_id = "ArthurZ/jukebox_prior_0"
        model = JukeboxPrior.from_pretrained(prior_id, min_duration=0).eval().half().to("cuda")

        labels = self.prepare_inputs(prior_id)[0].cuda()
        metadata = model.get_metadata(labels, 0, 7680, 0)
        set_seed(0)
        outputs = model.sample(1, metadata=metadata, sample_tokens=60)
        self.assertIn(outputs[0].cpu().tolist(), [self.EXPECTED_GPU_OUTPUTS_2, self.EXPECTED_GPU_OUTPUTS_2_PT_2])
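
# These are integration tests that download the full Jukebox checkpoints, so they are only
# meant to run when slow tests are enabled. A minimal sketch of running one of them locally,
# assuming the usual transformers test layout (the path below is an assumption, not taken
# from this file):
#
#   RUN_SLOW=1 python -m pytest tests/models/jukebox/test_modeling_jukebox.py -k "test_sampling"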