Spaces:
Runtime error
Runtime error
Update vocos_bark.py
Browse files- vocos_bark.py +14 -19
vocos_bark.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from vocos import Vocos
|
| 2 |
from typing import Dict, Optional, Tuple, Union
|
| 3 |
|
| 4 |
from transformers.models.bark import BarkSemanticModel, BarkCoarseModel, BarkFineModel, BarkPreTrainedModel
|
|
@@ -7,11 +6,10 @@ from transformers.models.bark.generation_configuration_bark import (
|
|
| 7 |
BarkFineGenerationConfig,
|
| 8 |
BarkSemanticGenerationConfig,
|
| 9 |
)
|
| 10 |
-
from transformers import BarkConfig
|
| 11 |
from transformers.modeling_utils import get_parameter_device
|
| 12 |
from transformers.utils import (
|
| 13 |
is_accelerate_available,
|
| 14 |
-
|
| 15 |
)
|
| 16 |
|
| 17 |
import torch
|
|
@@ -26,7 +24,8 @@ class BarkModel(BarkPreTrainedModel):
|
|
| 26 |
self.coarse_acoustics = BarkCoarseModel(config.coarse_acoustics_config)
|
| 27 |
self.fine_acoustics = BarkFineModel(config.fine_acoustics_config)
|
| 28 |
|
| 29 |
-
self.
|
|
|
|
| 30 |
self.config = config
|
| 31 |
|
| 32 |
@property
|
|
@@ -81,12 +80,20 @@ class BarkModel(BarkPreTrainedModel):
|
|
| 81 |
|
| 82 |
self.fine_acoustics_hook = hook
|
| 83 |
|
| 84 |
-
_, hook = cpu_offload_with_hook(self.
|
| 85 |
|
| 86 |
# We'll offload the last model manually.
|
| 87 |
self.codec_model_hook = hook
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
|
|
|
| 90 |
|
| 91 |
@torch.no_grad()
|
| 92 |
def generate(
|
|
@@ -197,18 +204,6 @@ class BarkModel(BarkPreTrainedModel):
|
|
| 197 |
# and load codec_model to GPU
|
| 198 |
# since bark doesn't use codec_model forward pass
|
| 199 |
self.fine_acoustics_hook.offload()
|
| 200 |
-
self.
|
| 201 |
-
|
| 202 |
-
# 4. Decode the output and generate audio array
|
| 203 |
-
bandwidth_id = torch.tensor([2]).to(self.device)
|
| 204 |
-
# transpose
|
| 205 |
-
value = output.transpose(0,1)
|
| 206 |
-
value = self.vocos.codes_to_features(value)
|
| 207 |
-
value = self.vocos.decode(value, bandwidth_id=bandwidth_id)
|
| 208 |
-
|
| 209 |
-
if getattr(self, "codec_model_hook", None) is not None:
|
| 210 |
-
# Offload codec_model to CPU
|
| 211 |
-
self.vocos.offload()
|
| 212 |
-
|
| 213 |
|
| 214 |
-
return
|
|
|
|
|
|
|
| 1 |
from typing import Dict, Optional, Tuple, Union
|
| 2 |
|
| 3 |
from transformers.models.bark import BarkSemanticModel, BarkCoarseModel, BarkFineModel, BarkPreTrainedModel
|
|
|
|
| 6 |
BarkFineGenerationConfig,
|
| 7 |
BarkSemanticGenerationConfig,
|
| 8 |
)
|
| 9 |
+
from transformers import BarkConfig, AutoModel
|
| 10 |
from transformers.modeling_utils import get_parameter_device
|
| 11 |
from transformers.utils import (
|
| 12 |
is_accelerate_available,
|
|
|
|
| 13 |
)
|
| 14 |
|
| 15 |
import torch
|
|
|
|
| 24 |
self.coarse_acoustics = BarkCoarseModel(config.coarse_acoustics_config)
|
| 25 |
self.fine_acoustics = BarkFineModel(config.fine_acoustics_config)
|
| 26 |
|
| 27 |
+
self.codec_model = AutoModel.from_config(config.codec_config)
|
| 28 |
+
|
| 29 |
self.config = config
|
| 30 |
|
| 31 |
@property
|
|
|
|
| 80 |
|
| 81 |
self.fine_acoustics_hook = hook
|
| 82 |
|
| 83 |
+
_, hook = cpu_offload_with_hook(self.codec_model, device, prev_module_hook=hook)
|
| 84 |
|
| 85 |
# We'll offload the last model manually.
|
| 86 |
self.codec_model_hook = hook
|
| 87 |
|
| 88 |
+
def codec_decode(self, fine_output):
|
| 89 |
+
"""Turn quantized audio codes into audio array using encodec."""
|
| 90 |
+
|
| 91 |
+
fine_output = fine_output.transpose(0, 1)
|
| 92 |
+
emb = self.codec_model.quantizer.decode(fine_output)
|
| 93 |
+
out = self.codec_model.decoder(emb)
|
| 94 |
+
audio_arr = out.squeeze(1) # squeeze the codebook dimension
|
| 95 |
|
| 96 |
+
return audio_arr
|
| 97 |
|
| 98 |
@torch.no_grad()
|
| 99 |
def generate(
|
|
|
|
| 204 |
# and load codec_model to GPU
|
| 205 |
# since bark doesn't use codec_model forward pass
|
| 206 |
self.fine_acoustics_hook.offload()
|
| 207 |
+
self.codec_model = self.codec_model.to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
+
return output
|