AudioGen dtypes
Files changed:
- audiocraft/builders.py +64 -108
- audiocraft/conditioners.py +0 -71
- audiocraft/lm.py +76 -23
- audiocraft/seanet.py +29 -14
- audiocraft/transformer.py +25 -23
- audiocraft/vq.py +16 -35
audiocraft/builders.py
CHANGED
@@ -1,15 +1,17 @@
 import omegaconf
 import torchaudio
 import torch
+from torch import nn
 import numpy as np
 from huggingface_hub import hf_hub_download
 import os
 from omegaconf import OmegaConf
 from .encodec import EncodecModel
-from .lm import LMModel
+from .lm import LMModel, TorchAutocast
 from .seanet import SEANetDecoder
 from .vq import ResidualVectorQuantizer
 
+
 N_REPEAT = 4  # num (virtual batch_size) clones of audio sounds
 
 def _shift(x):
@@ -29,13 +31,7 @@ def _delete_param(cfg, full_name):
     if parts[-1] in cfg:
         del cfg[parts[-1]]
     OmegaConf.set_struct(cfg, True)
-
-
-
-def dict_from_config(cfg):
-    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
-    return dct
-
+
 
 class AudioGen(torch.nn.Module):
 
@@ -44,21 +40,72 @@ class AudioGen(torch.nn.Module):
     def __init__(self):
 
        super().__init__()
-        self.
-
+        self.autocast = TorchAutocast(
+            enabled=True, device_type='cuda', dtype=torch.float16)
+        # Vocoder
+        _file_1 = hf_hub_download(
+            repo_id='facebook/audiogen-medium',
+            filename="compression_state_dict.bin",
+            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
+            library_name="audiocraft",
+            library_version='1.3.0a1')  # found at __init__.py  # audiocraft.__version__
+        pkg = torch.load(_file_1, map_location='cpu')
+        # kwargs = OmegaConf.create(pkg['xp.cfg'])
+        # kwargs.device = 'cpu'
+        decoder = SEANetDecoder()
+        quantizer = ResidualVectorQuantizer()
+        self.compression_model = EncodecModel(decoder=decoder,
+                                              quantizer=quantizer,
+                                              frame_rate=50,
+                                              renormalize=False,
+                                              sample_rate=16000,
+                                              channels=1,
+                                              causal=False)  # .to(cfg.device)
+        # self.compression_model = self.get_compression_model(cfg)
+        self.compression_model.load_state_dict(pkg['best_state'], strict=False)  # ckpt also has unused encoder weights
+        self.resample_fn = torchaudio.transforms.Resample(16000, 24000)  # AudioGen = 16 kHz, StyleTTS2 = 24 kHz, MMS-TTS = 24 kHz
 
-
-
-
+        # T5 & LM
+        _file_2 = hf_hub_download(
+            repo_id='facebook/audiogen-medium',
+            filename="state_dict.bin",
+            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
+            library_name="audiocraft",
+            library_version='1.3.0a1')  # found at __init__.py  # audiocraft.__version__
+        pkg = torch.load(_file_2, map_location='cpu')
+        cfg = OmegaConf.create(pkg['xp.cfg'])  # cfg is stored inside the torch .bin
+        _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
+        _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
+        _delete_param(cfg, 'conditioners.args.drop_desc_p')
+        _best = pkg['best_state']
+        _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
+        _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
+        self.lm = LMModel()  # .to(torch.float16)
+        self.lm.load_state_dict(pkg['best_state'],
+                                strict=True)
+
+        self.lm.eval()
+        self.compression_model.eval()
+
     def generate(self,
-
+                 prompt='dogs mewo',
                  duration=2.24,  # seconds of audio
                  ):
 
         with torch.no_grad():
-
-
-
+
+            with self.autocast:
+                # LM
+                gen_tokens = self.lm.generate(
+                    text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,  # '' is the null condition, e.g. ['trance', 'dogs meow', '', '']
+                    max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
+
             x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
 
             x = x[:, 0, :]  # last samples have splash sounds; DISCARD 25000 last samples
@@ -74,94 +121,3 @@ class AudioGen(torch.nn.Module):
 
             print(x.abs().max(), 'MAX')
             return x / (x.abs().max() + 1e-7)
-
-    def get_quantizer(self, quantizer, cfg, dimension):
-        klass = {
-            'no_quant': None,
-            'rvq': ResidualVectorQuantizer
-        }[quantizer]
-        kwargs = dict_from_config(getattr(cfg, quantizer))
-        if quantizer != 'no_quant':
-            kwargs['dimension'] = dimension
-        return klass(**kwargs)
-
-    def get_encodec_autoencoder(self, cfg):
-        kwargs = dict_from_config(getattr(cfg, 'seanet'))
-        _ = kwargs.pop('encoder')
-        decoder_override_kwargs = kwargs.pop('decoder')
-        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
-        decoder = SEANetDecoder(**decoder_kwargs)
-        return decoder
-
-    def get_compression_model(self, cfg):
-        """Instantiate a compression model."""
-        if cfg.compression_model == 'encodec':
-            kwargs = dict_from_config(getattr(cfg, 'encodec'))
-            quantizer_name = kwargs.pop('quantizer')
-            decoder = self.get_encodec_autoencoder(cfg)
-            quantizer = self.get_quantizer(quantizer_name, cfg, 128)
-            renormalize = kwargs.pop('renormalize', False)
-            # deprecated params
-            # print(f'{frame_rate=} {encoder.dimension=}')  -> frame_rate=50 encoder.dimension=128
-            kwargs.pop('renorm', None)
-            # print(kwargs)  -> {'autoencoder': 'seanet', 'sample_rate': 16000, 'channels': 1, 'causal': False}
-
-            return EncodecModel(decoder=decoder,
-                                quantizer=quantizer,
-                                frame_rate=50,
-                                renormalize=renormalize,
-                                sample_rate=16000,
-                                channels=1,
-                                causal=False
-                                ).to(cfg.device)
-        else:
-            raise KeyError(f"Unexpected compression model {cfg.compression_model}")
-
-    def load_compression_model(self):
-        file = hf_hub_download(
-            repo_id='facebook/audiogen-medium',
-            filename="compression_state_dict.bin",
-            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
-            library_name="audiocraft",
-            library_version='1.3.0a1')  # found at __init__.py  # audiocraft.__version__
-        pkg = torch.load(file, map_location='cpu')
-        # if 'pretrained' in pkg:
-        #     return EncodecModel.get_pretrained(pkg['pretrained'], device='cpu')
-        cfg = OmegaConf.create(pkg['xp.cfg'])
-        cfg.device = 'cpu'
-        model = self.get_compression_model(cfg)
-        model.load_state_dict(pkg['best_state'], strict=False)  # ckpt also has unused encoder weights
-        # return model
-        self.compression_model = model
-
-    def load_lm_model(self):
-        file = hf_hub_download(
-            repo_id='facebook/audiogen-medium',
-            filename="state_dict.bin",
-            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
-            library_name="audiocraft",
-            library_version='1.3.0a1')  # found at __init__.py  # audiocraft.__version__
-        pkg = torch.load(file,
-                         map_location='cpu')
-        cfg = OmegaConf.create(pkg['xp.cfg'])  # cfg is stored inside the torch .bin
-        _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
-        _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
-        _delete_param(cfg, 'conditioners.args.drop_desc_p')
-        print('CFG', cfg)
-        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
-        print('kwargs', kwargs)
-        model = LMModel().to(getattr(torch, cfg.dtype))  # .to(cfg.device)
-        _best = pkg['best_state']
-        _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
-        _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
-        model.load_state_dict(pkg['best_state'], strict=True)
-        # model.cfg = cfg
-        self.lm = model.to(torch.float)
-
-        # def _flush(self):
-        #     self.lm._flush()  # already done in lm generate at end
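For orientation, a minimal usage sketch of the refactored wrapper, assuming the module layout in this commit (everything is now assembled inside AudioGen.__init__, including the weight downloads). The soundfile writer and the 'dogs barking' prompt are arbitrary choices for illustration, not part of the diff, and a CUDA device is assumed since the T5 is pinned to 'cuda:0'.

import soundfile
import torch
from audiocraft.builders import AudioGen

model = AudioGen()  # __init__ downloads facebook/audiogen-medium and builds both stages
with torch.no_grad():
    wav = model.generate(prompt='dogs barking', duration=2.24)  # [batch, samples] at 16 kHz per the diff's comments
wav = model.resample_fn(wav.cpu())  # 16 kHz -> 24 kHz, matching the Resample set up in __init__
soundfile.write('audiogen_out.wav', wav[0].numpy(), 24000)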
audiocraft/conditioners.py
DELETED
@@ -1,71 +0,0 @@
-import warnings
-from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
-import torch
-from torch import nn
-
-
-class T5Conditioner(nn.Module):
-
-    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
-              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
-              "google/flan-t5-xl", "google/flan-t5-xxl"]
-    MODELS_DIMS = {
-        "t5-small": 512,
-        "t5-base": 768,
-        "t5-large": 1024,
-        "t5-3b": 1024,
-        "t5-11b": 1024,
-        "google/flan-t5-small": 512,
-        "google/flan-t5-base": 768,
-        "google/flan-t5-large": 1024,
-        "google/flan-t5-3b": 1024,
-        "google/flan-t5-11b": 1024,
-    }
-
-    def __init__(self,
-                 name,
-                 output_dim,
-                 device='cuda:0',
-                 finetune=False):
-        print(f'{finetune=}')
-        assert name in self.MODELS, f"Unrecognized t5 model name (should be in {self.MODELS})"
-        super().__init__()
-        self.dim = self.MODELS_DIMS[name]
-        self.output_dim = output_dim
-        self.output_proj = nn.Linear(self.dim, output_dim)
-        self.device = device
-        self.name = name
-
-        self.t5_tokenizer = T5Tokenizer.from_pretrained(name, legacy=True)
-        t5 = T5EncoderModel.from_pretrained(name).eval()  # .train(mode=finetune)
-        if finetune:
-            self.t5 = t5
-        else:
-            # this makes sure that the t5 model is not part
-            # of the saved checkpoint
-            self.__dict__['t5'] = t5.to(device)
-
-    def tokenize(self, x):
-
-        entries = [xi if xi is not None else "" for xi in x]
-
-        inputs = self.t5_tokenizer(entries,
-                                   return_tensors='pt',
-                                   padding=True).to(self.device)
-
-        return inputs  # 'input_ids', 'attention_mask'
-
-    def forward(self, descriptions):
-
-        d = self.tokenize(descriptions)
-
-        with torch.no_grad():
-            embeds = self.t5(input_ids=d['input_ids'],
-                             attention_mask=d['attention_mask']
-                             ).last_hidden_state  # no kv-cache for txt conditioning
-        embeds = self.output_proj(embeds.to(self.output_proj.weight))
-        embeds = (embeds * d['attention_mask'].unsqueeze(-1))
-
-        return embeds  # , d['attention_mask']
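The deleted conditioner's final mask multiply is the detail worth keeping in mind, since the same zeroing reappears in the new T5 module inside lm.py: padded token positions still come out of output_proj with non-zero values, so they are zeroed explicitly before being used as cross-attention keys/values. A toy illustration with made-up shapes:

import torch

bs, T, dim = 2, 5, 8
embeds = torch.randn(bs, T, dim)                 # projected T5 hidden states
attention_mask = torch.tensor([[1, 1, 1, 0, 0],  # first entry padded to length 5
                               [1, 1, 1, 1, 1]])
embeds = embeds * attention_mask.unsqueeze(-1)   # [bs, T, 1] broadcasts over dim
assert embeds[0, 3:].abs().sum() == 0            # padded rows are exactly zero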
audiocraft/lm.py
CHANGED
@@ -1,10 +1,69 @@
 import torch
-import torch.nn.functional as F
 from audiocraft.transformer import StreamingTransformer
 from torch import nn
-from
-
+from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
+
+
+class TorchAutocast:
+
+    def __init__(self, enabled: bool, *args, **kwargs):
+        self.autocast = torch.autocast(*args, **kwargs) if enabled else None
+
+    def __enter__(self):
+        if self.autocast is None:
+            return
+        try:
+            self.autocast.__enter__()
+        except RuntimeError:
+            device = self.autocast.device
+            dtype = self.autocast.fast_dtype
+            raise RuntimeError(
+                f"There was an error autocasting with dtype={dtype} device={device}\n"
+                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
+            )
+
+    def __exit__(self, *args, **kwargs):
+        if self.autocast is None:
+            return
+        self.autocast.__exit__(*args, **kwargs)
+
+
+class T5(nn.Module):
+
+    def __init__(self):
+        # run this from within the lm so it autocasts, matching the exact values of the T5 in official audiogen
+        super().__init__()
+
+        self.dim = 1024
+        self.output_dim = 1536
+        self.output_proj = nn.Linear(self.dim, self.output_dim)
+        self.autocast = TorchAutocast(enabled=True, device_type='cuda', dtype=torch.float)
+        self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
+        t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
+
+        # this makes sure that the t5 model is not part
+        # of the saved checkpoint
+        self.__dict__['t5'] = t5.to('cuda:0')
+
+    def forward(self, prompt):
+        with torch.set_grad_enabled(False), self.autocast:
+
+            bs = len(prompt) // 2
+
+            # txt 2 hidden
+            d = self.t5_tokenizer(prompt,
+                                  return_tensors='pt',
+                                  padding=True).to(self.output_proj.bias.device)
+            d['attention_mask'][bs:, :] = 0  # null condition: the t5 attn_mask should be zero
+
+            x = self.t5(input_ids=d['input_ids'],
+                        attention_mask=d['attention_mask']).last_hidden_state  # no kv-cache
+
+            # output_proj as float32
+            print('BEF PROJ', x[0, :, :].sum(), x[1, :, :].sum(), self.output_proj.weight.sum(), self.output_proj.weight.dtype, self.output_proj.bias.sum(), 'GEN\n\n143')
+            x = self.output_proj(x)  # nn.Linear() - produces a different result if there is no duplicate txt condition here
+            x[bs:, :, :] = 0  # venv/../site-packages/audiocraft/modules/conditioners.py -> tokenize()
+            return x
 
 class LMModel(nn.Module):
 
@@ -16,8 +75,7 @@ class LMModel(nn.Module):
                  hidden_scale=4,  # FFN of Transformer
                  ):
         super().__init__()
-        self.
-                 output_dim=dim)
+        self.t5 = T5()
         self.card = card  # 2048 ?
         self.n_draw = 6  # replicate the generation of each text in the batch this many times
         # a larger batch is more expensive than n_draw, as it re-runs the model bs times
@@ -49,14 +107,14 @@ class LMModel(nn.Module):
             cross_attention_src=condition_tensors,
             token_count=token_count
         )
-
+        print(out.sum(), out.dtype, 'TRSF out cust')
         logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(self.n_q)], dim=1)  # [2*bs, 4, 1, 2048]
-
+
         logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :]  # [3, 4, 1, 2048]
-
+        print(logits.sum(0).sum(2), logits.shape, 'SAMPL custom')
 
         # SAMPLE TOP K
-        k = 400  # 450 is nice sound, still train honk is clear!
+        k = 1  # 400  # 450 is nice sound, still train honk is clear!
         p = torch.softmax(logits, dim=3)
         top_k_value, _ = torch.topk(p, k, dim=3)  # [3, 4, 1, k]
         min_value_top_k = top_k_value[:, :, :, -1:]
@@ -67,36 +125,31 @@ class LMModel(nn.Module):
         p = p.reshape(bs * self.n_q, 2048)
         out = torch.multinomial(p,  # p=[bs, 2048], out=[bs, num_samples]
                                 num_samples=self.n_draw,
-                                replacement=
+                                replacement=True)  # [bs*4, self.n_draw]
         # print('DRAW', 'c', out)
         return out.reshape(bs, self.n_q, self.n_draw).transpose(1, 2)  # [bs=3 not 6, self.n_draw, 4]
 
     @torch.no_grad()
     def generate(self,
-
-
-
-
-        bs
-        text_condition = torch.cat(
-            [
-                text_condition,
-                torch.zeros_like(text_condition)
-            ], 0)
+                 max_tokens=None,
+                 text_condition=None
+                 ):
+        x = self.t5(text_condition)
+        bs = x.shape[0] // 2  # has null conditions - bs*2*N_REPEAT applies in builders.py
         out_codes = torch.full((bs,
                                 self.n_draw,
                                 4,
                                 4 + 3 + max_tokens),  # 4 + max_tokens + 4-1: sufficient to index the 1st antidiagonal of the 4x4 + 4 extra tokens
                                self.card,
                                dtype=torch.long,
-                               device=
+                               device=x.device)  # [bs, n_draw, 4, dur]
         # =========================================
         for offset in range(0, max_tokens + 4 - 1):  # max_tokens + n_q - 1
 
             # extract diagonal via indexing out_codes[[0, 1, 2, 3], [0, 1, 2, 3]]
             next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],  # index diagonal & expand to [bs, n_q, dur=1]
                                       # gen_sequence[:, 0, :, offset-1:offset]  # DIAG-INDEXING for setting the lm prediction into gen_sequence. gen_sequence has to be un-delayed in the end (de-delayed for the vocoder), so only the lm input needs to see the delay; we can therefore feed by diag-gather, e.g. a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5]. gen_sequence is indexed by vertical column and fed to the lm, while the lm's prediction is placed diagonally, with delay, into gen_sequence.
-                                      condition_tensors=
+                                      condition_tensors=x,  # utilisation of the attention mask of the txt condition?
                                       token_count=offset)  # [bs, n_draw, 4]
 
             # The fill of next_token should also be placed on the antidiagonal [not a column]
@@ -138,7 +191,7 @@ class LMModel(nn.Module):
                 pass  # print('No delete anti-diag')
 
             out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
-
+        print('\nFULL FINAL TOKENS UNFILT\n', out_codes[:, 0, :, 4:max_tokens+4], out_codes[0, 0, :, 4:max_tokens+4].shape)
         # EXTRACT COLUMNS, AS ALIGNMENT IS ALREADY DONE BY FILLING DIAGONALLY
         out_codes = out_codes[:, :, :, 4:max_tokens+4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)  # [bs, 4, duration*n_draw] DISCARD FILL 2048
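The sampling path above combines conditional and null-condition logits (3*cond - 2*uncond, a classifier-free-guidance-style extrapolation with scale 3), keeps only the top-k probability mass, and draws n_draw samples per codebook per step. A self-contained sketch with toy tensors, shapes following the diff; note it uses the commented-out k=400 setting, since the k=1 left in the committed hunk looks like a temporary debug value:

import torch

bs, n_q, card, n_draw, k = 3, 4, 2048, 6, 400
logits = torch.randn(2 * bs, n_q, 1, card)          # first half cond, second half null-cond
logits = 3 * logits[:bs] - 2 * logits[bs:]          # guidance-style extrapolation

p = torch.softmax(logits, dim=3)
top_k_value, _ = torch.topk(p, k, dim=3)            # [bs, n_q, 1, k]
min_value_top_k = top_k_value[..., -1:]
p = torch.where(p < min_value_top_k, 0.0, p)        # zero everything below the k-th prob

p = p.reshape(bs * n_q, card)                       # multinomial renormalizes the weights internally
out = torch.multinomial(p, num_samples=n_draw, replacement=True)
out = out.reshape(bs, n_q, n_draw).transpose(1, 2)  # [bs, n_draw, n_q], as in the diff
print(out.shape)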
audiocraft/seanet.py
CHANGED
@@ -82,21 +82,36 @@ class SEANetResnetBlock(nn.Module):
 
 
 class SEANetDecoder(nn.Module):
-
-
-
-
+    # channels=1 dimension=128 n_filters=64 n_residual_layers=1 ratios=[8, 5, 4, 2]
+    # activation='ELU' activation_params={'alpha': 1.0} final_activation=None
+    # final_activation_params=None norm='weight_norm'
+    # norm_params={} kernel_size=7 last_kernel_size=7 residual_kernel_size=3 dilation_base=2
+    # causal=False pad_mode='constant'
+    # true_skip=True compress=2 lstm=2 disable_norm_outer_blocks=0 trim_right_ratio=1.0
+
+    def __init__(self,
+                 channels=1,
+                 dimension=128,
+                 n_filters=64,
+                 n_residual_layers=1,
+                 ratios=[8, 5, 4, 2],
+                 activation='ELU',
                  activation_params: dict = {'alpha': 1.0},
-                 final_activation
-                 final_activation_params
-                 norm
-
-
-
-
-
-
-
+                 final_activation=None,
+                 final_activation_params=None,
+                 norm='weight_norm',
+                 norm_params={},
+                 kernel_size=7,
+                 last_kernel_size=7,
+                 residual_kernel_size=3,
+                 dilation_base=2,
+                 causal=False,
+                 pad_mode='constant',
+                 true_skip=True,
+                 compress=2,
+                 lstm=2,
+                 disable_norm_outer_blocks=0,
+                 trim_right_ratio=1.0):
         super().__init__()
         self.dimension = dimension
         self.channels = channels
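The inlined defaults make the vocoder's geometry explicit: the decoder upsamples by the product of ratios, which is what pins the frame_rate=50 hard-coded in builders.py. A quick arithmetic check:

ratios = [8, 5, 4, 2]
hop = 1
for r in ratios:
    hop *= r
print(hop)          # 320 output samples per latent frame
print(16000 / hop)  # 50.0 -> frame_rate=50 at sample_rate=16000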
audiocraft/transformer.py
CHANGED
@@ -53,16 +53,13 @@ class StreamingMultiheadAttention(nn.Module):
         # 1st projection makes k, v (instantaneous)
         # Here, else is self-attention of audio with itself (above is cross-attention to txt)
 
-
         # HISTORY - DIFFERENT FOR EACH TRANSFORMER LAYER
 
-        projected = nn.functional.linear(query, self.in_proj_weight)
+        projected = nn.functional.linear(query, self.in_proj_weight, None)  # here we get different floating-point values from the official implementation
 
         bound_layout = "b h p t d"
         packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
         q, k, v = packed.unbind(dim=2)
-
-
         if self.k_history is not None:
             # flush
             if self.k_history.shape[2] > 71:
@@ -84,7 +81,7 @@ class StreamingMultiheadAttention(nn.Module):
 
 
         # KV COMPLETION ONLY ON SELF-ATTENTION
-
+
         x = torch.nn.functional.scaled_dot_product_attention(
             q, k, v, is_causal=False, dropout_p=0
         )
@@ -93,15 +90,24 @@ class StreamingMultiheadAttention(nn.Module):
         return x
 
 
-class StreamingTransformerLayer(nn.Module):
+class StreamingTransformerLayer(nn.TransformerEncoderLayer):
+
     def __init__(self,
                  d_model,
                  num_heads,
                  dim_feedforward):
-
-
-
+
+        super().__init__(d_model,
+                         num_heads,
+                         dim_feedforward=dim_feedforward,
+                         dropout=0.0,
+                         device='cuda',
+                         dtype=torch.float32,
+                         batch_first=True,
+                         norm_first=True,
+                         activation='gelu')
+        # super().__init__()
 
         self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                      num_heads=num_heads)
@@ -116,20 +122,13 @@ class StreamingTransformerLayer(nn.Module):
 
 
     def forward(self,
-
+                x,
                 cross_attention_src=None):  # txt cond
-
-
-        x = src
-
+        # x = src
         x = x + self.self_attn(self.norm1(x))
-
-
-
-            query=self.norm_cross(x),
-            key=cross_attention_src,
-            value=cross_attention_src)  # txt condition
-
+        x = x + self.cross_attention(query=self.norm_cross(x),
+                                     key=cross_attention_src,
+                                     value=cross_attention_src)  # txt condition
         x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
         return x
 
@@ -164,9 +163,12 @@ class StreamingTransformer(nn.Module):
                 x,
                 token_count=None,
                 cross_attention_src=None):
-
+        # cross_attention_src = (torch.arange(1536)[None, None, :] * torch.arange(6).reshape(2, 3, 1)).to(torch.float).to(x.device)
+        print(x.sum(), cross_attention_src[0, :, :].sum(), cross_attention_src[1, :, :].sum(), x.dtype, cross_attention_src.dtype, 'Xattnc')
         if self.positional_embedding in ['sin', 'sin_rope']:
-            pos_emb = create_sin_embedding(torch.
+            pos_emb = create_sin_embedding(torch.zeros(x.shape[0], 1, 1, device=x.device) + token_count,
+                                           1536,
+                                           max_period=self.max_period)
 
         x = x + pos_emb
         for j, lay in enumerate(self.layers):
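A hedged sketch of what the create_sin_embedding call above computes during incremental decoding: only one new token is fed per step, so the position is the running token_count rather than an arange over a full sequence. This standalone version follows the usual audiocraft-style formula (cosines in the first half of the channels, sines in the second); treat it as an illustration under that assumption, not as the library function itself.

import torch

def sin_embedding(positions, dim, max_period=10000.0):
    # positions: [bs, T, 1]; returns [bs, T, dim]
    half = dim // 2
    adim = torch.arange(half, device=positions.device).view(1, 1, -1)
    phase = positions / (max_period ** (adim / (half - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

token_count = 7                                    # current decoding step
x = torch.zeros(2, 1, 1536)                        # [bs, 1 new token, d_model]
pos = torch.zeros(x.shape[0], 1, 1) + token_count  # same offset for the whole batch
x = x + sin_embedding(pos, 1536)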
audiocraft/vq.py
CHANGED
@@ -145,46 +145,27 @@ class ResidualVectorQuantization(nn.Module):
             layer = self.layers[i]
             quantized = layer.decode(indices)
             quantized_out = quantized_out + quantized
-            return quantized_out
-
-
-
-
-# ------------------------------------- END core_vq.py
+        return quantized_out
 
 
 class ResidualVectorQuantizer(nn.Module):
-
-
-
-
-        n_q (int): Number of residual vector quantizers used.
-        q_dropout (bool): Random quantizer dropout at train time.
-        bins (int): Codebook size.
-        decay (float): Decay for the exponential moving average over the codebooks.
-        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
-        kmeans_iters (int): Number of iterations used for k-means initialization.
-        threshold_ema_dead_code (int): Threshold for dead-code expiration. Replace any code
-            that has an exponential moving average cluster size less than the specified threshold
-            with a randomly selected vector from the current batch.
-        orthogonal_reg_weight (float): Orthogonal regularization weight.
-        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
-        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
-            for orthogonal regularization.
-    """
+
+    # dimension=128 n_q=4 q_dropout=False bins=2048 decay=0.99 kmeans_init=True kmeans_iters=50 threshold_ema_dead_code=2
+    # orthogonal_reg_weight=0.0 orthogonal_reg_active_codes_only=False orthogonal_reg_max_codes=None
+
     def __init__(
         self,
-        dimension
-        n_q
-        q_dropout
-        bins
-        decay
-        kmeans_init
-        kmeans_iters
-        threshold_ema_dead_code
-        orthogonal_reg_weight
-        orthogonal_reg_active_codes_only
-        orthogonal_reg_max_codes
+        dimension=128,
+        n_q=4,
+        q_dropout=False,
+        bins=2048,
+        decay=0.99,
+        kmeans_init=True,
+        kmeans_iters=50,
+        threshold_ema_dead_code=2,
+        orthogonal_reg_weight=0.0,
+        orthogonal_reg_active_codes_only=False,
+        orthogonal_reg_max_codes=None,
    ):
         super().__init__()
         self.max_n_q = n_q
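A toy illustration of the residual decode loop kept in the hunk above: each quantizer layer contributes a refinement embedding, and decoding simply sums the per-layer lookups (which is why the hunk's indentation fix matters: returning inside the loop would use only the first codebook). The random codebooks here are stand-ins for the trained ones.

import torch

n_q, bins, dim, T = 4, 2048, 128, 10
codebooks = [torch.randn(bins, dim) for _ in range(n_q)]
indices = torch.randint(0, bins, (n_q, T))  # one code per layer per frame

quantized_out = torch.zeros(T, dim)
for i in range(n_q):
    quantized_out = quantized_out + codebooks[i][indices[i]]  # residual sum across layers
print(quantized_out.shape)  # torch.Size([10, 128])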