Dionyssos committed
Commit 0230db1 · 1 Parent(s): 4898f81

AudioGen dtypes

audiocraft/builders.py CHANGED
@@ -1,15 +1,17 @@
 import omegaconf
 import torchaudio
 import torch
+from torch import nn
 import numpy as np
 from huggingface_hub import hf_hub_download
 import os
 from omegaconf import OmegaConf
 from .encodec import EncodecModel
-from .lm import LMModel
+from .lm import LMModel, TorchAutocast
 from .seanet import SEANetDecoder
 from .vq import ResidualVectorQuantizer
 
+
 N_REPEAT = 4  # num (virtual batch_size) clones of audio sounds
 
 def _shift(x):
@@ -29,13 +31,7 @@ def _delete_param(cfg, full_name):
     if parts[-1] in cfg:
         del cfg[parts[-1]]
     OmegaConf.set_struct(cfg, True)
-
-
-
-def dict_from_config(cfg):
-    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
-    return dct
-
+
 
 class AudioGen(torch.nn.Module):
@@ -44,21 +40,72 @@ class AudioGen(torch.nn.Module):
     def __init__(self):
 
         super().__init__()
-        self.load_compression_model()
-        self.load_lm_model()
-
-        # AudioGen = 16KHZ StyleTTS2 = 24 KHz / MMSTTS = 24 KHz
-        self.resample_fn = torchaudio.transforms.Resample(16000, 24000)
-
+        self.autocast = TorchAutocast(
+            enabled=True, device_type='cuda', dtype=torch.float16)
+        # Vocoder
+        _file_1 = hf_hub_download(
+            repo_id='facebook/audiogen-medium',
+            filename="compression_state_dict.bin",
+            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
+            library_name="audiocraft",
+            library_version='1.3.0a1')  # Found at __init__.py  # audiocraft.__version__
+        pkg = torch.load(_file_1, map_location='cpu')
+        # kwargs = OmegaConf.create(pkg['xp.cfg'])
+        # kwargs.device = 'cpu'
+        decoder = SEANetDecoder()
+        quantizer = ResidualVectorQuantizer()
+        self.compression_model = EncodecModel(decoder=decoder,
+                                              quantizer=quantizer,
+                                              frame_rate=50,
+                                              renormalize=False,
+                                              sample_rate=16000,
+                                              channels=1,
+                                              causal=False)  # .to(cfg.device)
+        # self.compression_model = self.get_compression_model(cfg)
+        self.compression_model.load_state_dict(pkg['best_state'], strict=False)  # ckpt has also unused encoder weights
+        self.resample_fn = torchaudio.transforms.Resample(16000, 24000)  # AudioGen = 16KHZ StyleTTS2 = 24 KHz / MMSTTS = 24 KHz
+        # # T5 &
+        # LM
+
+        _file_2 = hf_hub_download(
+            repo_id='facebook/audiogen-medium',
+            filename="state_dict.bin",
+            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
+            library_name="audiocraft",
+            library_version='1.3.0a1')  # Found at __init__.py  # audiocraft.__version__
+        pkg = torch.load(_file_2, map_location='cpu')
+        cfg = OmegaConf.create(pkg['xp.cfg'])  # CFG inside torch bin
+        _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
+        _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
+        _delete_param(cfg, 'conditioners.args.drop_desc_p')
+        # print('___________________________CFG___________________', cfg, '\n=======================')
+        # kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
+        # print('___________________________Kwarg___________________', kwargs, '\n=======================')
+        _best = pkg['best_state']
+        # _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
+        # _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
+        _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')  # .to(torch.float)
+        _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')  # .to(torch.float)
+        self.lm = LMModel()  # .to(torch.float16)
+        self.lm.load_state_dict(pkg['best_state'],
+                                strict=True)
+        self.lm.eval()
+        self.compression_model.eval()
+
     def generate(self,
-                 descriptions,
+                 prompt='dogs mewo',
                  duration=2.24,  ## seconds of audio
                  ):
 
         with torch.no_grad():
-            gen_tokens = self.lm.generate(
-                descriptions=[descriptions] * N_REPEAT,
-                max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
+            with self.autocast:
+                # LM
+                gen_tokens = self.lm.generate(
+                    text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,  # '' for null condition, # ['trance', 'dogs meow', '', '']
+                    max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
+
             x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
 
             x = x[:, 0, :]  # last samples have splash sounds DISCARD 25000 last samples
@@ -74,94 +121,3 @@ class AudioGen(torch.nn.Module):
 
         print(x.abs().max(), 'MAX')
         return x / (x.abs().max() + 1e-7)
-
-    def get_quantizer(self, quantizer, cfg, dimension):
-        klass = {
-            'no_quant': None,
-            'rvq': ResidualVectorQuantizer
-        }[quantizer]
-        kwargs = dict_from_config(getattr(cfg, quantizer))
-        if quantizer != 'no_quant':
-            kwargs['dimension'] = dimension
-        return klass(**kwargs)
-
-    def get_encodec_autoencoder(self, cfg):
-        kwargs = dict_from_config(getattr(cfg, 'seanet'))
-        _ = kwargs.pop('encoder')
-        decoder_override_kwargs = kwargs.pop('decoder')
-        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
-        decoder = SEANetDecoder(**decoder_kwargs)
-        return decoder
-
-    def get_compression_model(self, cfg):
-        """Instantiate a compression model."""
-        if cfg.compression_model == 'encodec':
-            kwargs = dict_from_config(getattr(cfg, 'encodec'))
-            quantizer_name = kwargs.pop('quantizer')
-            decoder = self.get_encodec_autoencoder(cfg)
-            quantizer = self.get_quantizer(quantizer_name, cfg, 128)
-            renormalize = kwargs.pop('renormalize', False)
-            # deprecated params
-            # print(f'{frame_rate=} {encoder.dimension=}')  frame_rate=50 encoder.dimension=128
-            kwargs.pop('renorm', None)
-            # print('\n______!____________\n', kwargs, '\n______!____________\n')
-            # ______!____________
-            # {'autoencoder': 'seanet', 'sample_rate': 16000, 'channels': 1, 'causal': False}
-            # ______!____________
-
-            return EncodecModel(decoder=decoder,
-                                quantizer=quantizer,
-                                frame_rate=50,
-                                renormalize=renormalize,
-                                sample_rate=16000,
-                                channels=1,
-                                causal=False
-                                ).to(cfg.device)
-        else:
-            raise KeyError(f"Unexpected compression model {cfg.compression_model}")
-
-    def load_compression_model(self):
-        file = hf_hub_download(
-            repo_id='facebook/audiogen-medium',
-            filename="compression_state_dict.bin",
-            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
-            library_name="audiocraft",
-            library_version='1.3.0a1')  # Found at __init__.py  # audiocraft.__version__
-        pkg = torch.load(file, map_location='cpu')
-        # if 'pretrained' in pkg:
-        #     print('NO RPtrained\n=\n=\n=\n=\n=')
-        #     return EncodecModel.get_pretrained(pkg['pretrained'], device='cpu')
-        cfg = OmegaConf.create(pkg['xp.cfg'])
-        cfg.device = 'cpu'
-        model = self.get_compression_model(cfg)
-        model.load_state_dict(pkg['best_state'], strict=False)  # ckpt has also unused encoder weights
-        # return model
-        self.compression_model = model
-
-    def load_lm_model(self):
-        file = hf_hub_download(
-            repo_id='facebook/audiogen-medium',
-            filename="state_dict.bin",
-            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
-            library_name="audiocraft",
-            library_version='1.3.0a1')  # Found at __init__.py  # audiocraft.__version__
-        pkg = torch.load(file, map_location='cpu')
-        cfg = OmegaConf.create(pkg['xp.cfg'])  # CFG inside torch bin
-        _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
-        _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
-        _delete_param(cfg, 'conditioners.args.drop_desc_p')
-        print('___________________________CFG___________________', cfg, '\n=======================')
-        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
-        print('___________________________Kwarg___________________', kwargs, '\n=======================')
-        model = LMModel().to(getattr(torch, cfg.dtype))  # .to(cfg.device)
-        _best = pkg['best_state']
-        _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
-        _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
-        model.load_state_dict(pkg['best_state'], strict=True)
-        # model.cfg = cfg
-        self.lm = model.to(torch.float)
-
-    # def _flush(self):
-    #     self.lm._flush()  # already done in lm generate at end
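A minimal usage sketch for the rewired AudioGen class (not part of the commit; the import path, the CUDA placement, and the optional 24 kHz resample step are assumptions based on the attributes defined in this file):

    import torch
    from audiocraft.builders import AudioGen  # assumed import path

    model = AudioGen().cuda().eval()  # TorchAutocast in generate() is configured for CUDA
    with torch.no_grad():
        wav16 = model.generate(prompt='dogs barking', duration=2.24)  # peak-normalized 16 kHz audio
        wav24 = model.resample_fn(wav16)  # optional: match the 24 kHz StyleTTS2 / MMS-TTS pipelines noted above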
audiocraft/conditioners.py DELETED
@@ -1,71 +0,0 @@
-import warnings
-from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
-import torch
-from torch import nn
-
-
-class T5Conditioner(nn.Module):
-
-    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
-              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
-              "google/flan-t5-xl", "google/flan-t5-xxl"]
-    MODELS_DIMS = {
-        "t5-small": 512,
-        "t5-base": 768,
-        "t5-large": 1024,
-        "t5-3b": 1024,
-        "t5-11b": 1024,
-        "google/flan-t5-small": 512,
-        "google/flan-t5-base": 768,
-        "google/flan-t5-large": 1024,
-        "google/flan-t5-3b": 1024,
-        "google/flan-t5-11b": 1024,
-    }
-
-    def __init__(self,
-                 name,
-                 output_dim,
-                 device='cuda:0',
-                 finetune=False):
-        print(f'{finetune=}')
-        assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
-        super().__init__()
-        self.dim = self.MODELS_DIMS[name]
-        self.output_dim = output_dim
-        self.output_proj = nn.Linear(self.dim, output_dim)
-        self.device = device
-        self.name = name
-
-        self.t5_tokenizer = T5Tokenizer.from_pretrained(name, legacy=True)
-        t5 = T5EncoderModel.from_pretrained(name).eval()  # .train(mode=finetune)
-        if finetune:
-            self.t5 = t5
-        else:
-            # this makes sure that the t5 models is not part
-            # of the saved checkpoint
-            self.__dict__['t5'] = t5.to(device)
-
-    def tokenize(self, x):
-
-        entries = [xi if xi is not None else "" for xi in x]
-
-        inputs = self.t5_tokenizer(entries,
-                                   return_tensors='pt',
-                                   padding=True).to(self.device)
-
-        return inputs  # 'input_ids' 'attention_mask'
-
-    def forward(self, descriptions):
-
-        d = self.tokenize(descriptions)
-
-        with torch.no_grad():
-            embeds = self.t5(input_ids=d['input_ids'],
-                             attention_mask=d['attention_mask']
-                             ).last_hidden_state  # no kvcache for txt conditioning
-        embeds = self.output_proj(embeds.to(self.output_proj.weight))
-        embeds = (embeds * d['attention_mask'].unsqueeze(-1))
-
-        return embeds  # , d['attention_mask']
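The deleted T5Conditioner lives on, in simplified form, as the T5 class inside audiocraft/lm.py below. For reference, the conditioning it computed can be reproduced standalone with transformers (a sketch, not the committed code; 't5-large' and the 1024-to-1536 projection mirror the values hard-coded in the new lm.py):

    import torch
    from torch import nn
    from transformers import T5EncoderModel, T5Tokenizer

    tok = T5Tokenizer.from_pretrained('t5-large', legacy=True)
    enc = T5EncoderModel.from_pretrained('t5-large').eval()
    proj = nn.Linear(1024, 1536)  # in practice loaded from the AudioGen checkpoint

    texts = ['dogs barking', '']  # the empty string plays the role of the null condition
    d = tok(texts, return_tensors='pt', padding=True)
    with torch.no_grad():
        h = enc(input_ids=d['input_ids'], attention_mask=d['attention_mask']).last_hidden_state
    emb = proj(h) * d['attention_mask'].unsqueeze(-1)  # zero out padded positions, as forward() did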
audiocraft/lm.py CHANGED
@@ -1,10 +1,69 @@
 import torch
-import torch.nn.functional as F
 from audiocraft.transformer import StreamingTransformer
 from torch import nn
-from audiocraft.conditioners import T5Conditioner
-import numpy as np
+from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
+
+
+class TorchAutocast:
+
+    def __init__(self, enabled: bool, *args, **kwargs):
+        self.autocast = torch.autocast(*args, **kwargs) if enabled else None
+
+    def __enter__(self):
+        if self.autocast is None:
+            return
+        try:
+            self.autocast.__enter__()
+        except RuntimeError:
+            device = self.autocast.device
+            dtype = self.autocast.fast_dtype
+            raise RuntimeError(
+                f"There was an error autocasting with dtype={dtype} device={device}\n"
+                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
+            )
+
+    def __exit__(self, *args, **kwargs):
+        if self.autocast is None:
+            return
+        self.autocast.__exit__(*args, **kwargs)
+
+
+class T5(nn.Module):
+
+    def __init__(self):
+        # run this from within lm so it autocasts thus match exact values of t5 in official audiogen
+        super().__init__()
+
+        self.dim = 1024
+        self.output_dim = 1536
+        self.output_proj = nn.Linear(self.dim, self.output_dim)
+        self.autocast = TorchAutocast(enabled=True, device_type='cuda', dtype=torch.float)
+        self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
+        t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
+
+        # this makes sure that the t5 model is not part
+        # of the saved checkpoint
+        self.__dict__['t5'] = t5.to('cuda:0')
+
+    def forward(self, prompt):
+        with torch.set_grad_enabled(False), self.autocast:
+
+            bs = len(prompt) // 2
+
+            # txt 2 hidden
+            d = self.t5_tokenizer(prompt,
+                                  return_tensors='pt',
+                                  padding=True).to(self.output_proj.bias.device)
+            d['attention_mask'][bs:, :] = 0  # null condition t5 attn_mask should be zero
+
+            x = self.t5(input_ids=d['input_ids'],
+                        attention_mask=d['attention_mask']).last_hidden_state  # no kv
+
+            # output_proj as float32
+            print('BEF PROJ', x[0, :, :].sum(), x[1, :, :].sum(), self.output_proj.weight.sum(), self.output_proj.weight.dtype, self.output_proj.bias.sum(), 'GEN\n\n143')
+            x = self.output_proj(x)  # nn.Linear() - produces different result if there is no duplicate txt condition here
+            x[bs:, :, :] = 0  # venv/../site-packages/audiocraft/modules/conditioners.py -> tokenize()
+            return x
+
 
 class LMModel(nn.Module):
@@ -16,8 +75,7 @@ class LMModel(nn.Module):
                  hidden_scale = 4,  # FFN of Transformer
                  ):
         super().__init__()
-        self.condition_provider = T5Conditioner(name='t5-large',
-                                                output_dim=dim)
+        self.t5 = T5()
         self.card = card  # 2048 ?
         self.n_draw = 6  # replicate so many times the generation of each text in batch
         # the batch is more expensive than n_draw as it re-runs the model bs times
@@ -49,14 +107,14 @@ class LMModel(nn.Module):
             cross_attention_src=condition_tensors,
             token_count=token_count
         )
-
+        print(out.sum(), out.dtype, 'TRSF out cust')
         logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(self.n_q)], dim=1)  # [2*bs,4,1,2048]
-
+
         logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :]  # [3, 4, 1, 2048]
-
+        print(logits.sum(0).sum(2), logits.shape, 'SAMPL custom')
 
         # SAMPLE TOP K
-        k = 400  # 450 is nice sound still train honk is clear!
+        k = 1  # 400 # 450 is nice sound still train honk is clear!
         p = torch.softmax(logits, dim=3)
         top_k_value, _ = torch.topk(p, k, dim=3)  # [3, 4, 1, k]
         min_value_top_k = top_k_value[:, :, :, -1:]
@@ -67,36 +125,31 @@ class LMModel(nn.Module):
         p = p.reshape(bs * self.n_q, 2048)
         out = torch.multinomial(p,  # p=[bs,2048], out=[bs, num_samples]
                                 num_samples=self.n_draw,
-                                replacement=False)  # [bs*4, self.n_draw]
+                                replacement=True)  # [bs*4, self.n_draw]
         # print('DRAW', 'c', out)
         return out.reshape(bs, self.n_q, self.n_draw).transpose(1, 2)  # [bs=3not6, self.n_draw, 4]
 
     @torch.no_grad()
     def generate(self,
-                 descriptions = ['windy day', 'rain storm'],
-                 max_tokens = None):
-
-        text_condition = self.condition_provider(descriptions)
-        bs, _, _ = text_condition.shape
-        text_condition = torch.cat(
-            [
-                text_condition,
-                torch.zeros_like(text_condition)
-            ], 0)
+                 max_tokens=None,
+                 text_condition=None
+                 ):
+        x = self.t5(text_condition)
+        bs = x.shape[0] // 2  # has null conditions - bs*2*N_REPEAT applys in builders.py
         out_codes = torch.full((bs,
                                 self.n_draw,
                                 4,
                                 4 + 3 + max_tokens),  # 4 + max_tokens + 4-1 to have sufficient to index the 1st antidiagonal of 4x4 + 4 xtra tokens
                                self.card,
                                dtype=torch.long,
-                               device=text_condition.device)  # [bs, n_draw, 4, dur]
+                               device=x.device)  # [bs, n_draw, 4, dur]
         # =========================================
         for offset in range(0, max_tokens + 4 - 1):  # max_tokens + n_q - 1
 
             # extract diagonal via indexing out_codes[ [0, 1, 2, 3], [0, 1, 2, 3] ]
             next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],  # index diagonal & exapnd to [bs, n_q, dur=1]
                                       # gen_sequence[:, 0, :, offset-1:offset], # DIAGINDEXING for setting prediction of lm into gen_sequence THE GENSEQUENCE has to be un-delayed in the end [Because it has to be de-delayed for the vocoder then is actually only the lm input that requires to see the delay thus we could just feed by diaggather] so it matches gen_codes -1 a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5] the gen_sequence is indexed by vertical column and fed to lm however the prediction of lm is place diagonally with delay to the gen_sequence
-                                      condition_tensors=text_condition,  # utilisation of the attention mask of txt condition ?
+                                      condition_tensors=x,  # utilisation of the attention mask of txt condition ?
                                       token_count=offset)  # [bs, n_draw, 4]
 
             # Fill of next_token should be also placed on antidiagonal [not column]
@@ -138,7 +191,7 @@ class LMModel(nn.Module):
                 pass  # print('No delete anti-diag')
 
             out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
-        # END LOOP
+        print('\nFULL FINAL TOKENS UNFILT\n', out_codes[:, 0, :, 4:max_tokens+4], out_codes[0, 0, :, 4:max_tokens+4].shape)
         # EXTRACT COLUMNS AS ALIGN IS ALREADY DONE by FILLING DIAGONALLY
         out_codes = out_codes[:, :, :, 4:max_tokens+4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)  # [bs, 4, duration*n_draw] DISCARD FILL 2048
 
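The sampling step in LMModel.forward combines classifier-free guidance with top-k filtering and n_draw multinomial draws; in isolation it looks roughly like this (a sketch with illustrative shapes, not the committed code; note this commit sets k=1 and replacement=True):

    import torch

    bs, n_q, card, k, n_draw = 4, 4, 2048, 250, 6
    logits = torch.randn(2 * bs, n_q, 1, card)       # first bs rows: text condition, last bs rows: null condition
    logits = 3 * logits[:bs] - 2 * logits[bs:]       # classifier-free guidance, as in forward() above
    p = torch.softmax(logits, dim=3)
    top_k, _ = torch.topk(p, k, dim=3)
    p = torch.where(p < top_k[..., -1:], torch.zeros_like(p), p)  # keep only the k most likely bins
    p = p.reshape(bs * n_q, card)
    tokens = torch.multinomial(p, num_samples=n_draw, replacement=True)  # n_draw variants per codebook
    tokens = tokens.reshape(bs, n_q, n_draw).transpose(1, 2)             # [bs, n_draw, n_q]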
audiocraft/seanet.py CHANGED
@@ -82,21 +82,36 @@ class SEANetResnetBlock(nn.Module):
 
 
 class SEANetDecoder(nn.Module):
-
-    def __init__(self, channels: int = 1,
-                 dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
-                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU',
+    # channels=1 dimension=128 n_filters=64 n_residual_layers=1 ratios=[8, 5, 4, 2]
+    # activation='ELU' activation_params={'alpha': 1.0}, final_activation=None
+    # final_activation_params=None norm='weight_norm'
+    # norm_params={} kernel_size=7 last_kernel_size=7 residual_kernel_size=3 dilation_base=2
+    # causal=False pad_mode='constant'
+    # true_skip=True compress=2 lstm=2 disable_norm_outer_blocks=0 trim_right_ratio=1.0
+
+    def __init__(self,
+                 channels = 1,
+                 dimension = 128,
+                 n_filters = 64,
+                 n_residual_layers = 1,
+                 ratios = [8, 5, 4, 2],
+                 activation = 'ELU',
                  activation_params: dict = {'alpha': 1.0},
-                 final_activation: tp.Optional[str] = None,
-                 final_activation_params: tp.Optional[dict] = None,
-                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {},
-                 kernel_size: int = 7,
-                 last_kernel_size: int = 7, residual_kernel_size: int = 3,
-                 dilation_base: int = 2, causal: bool = False,
-                 pad_mode: str = 'reflect', true_skip: bool = True,
-                 compress: int = 2, lstm: int = 0,
-                 disable_norm_outer_blocks: int = 0,
-                 trim_right_ratio: float = 1.0):
+                 final_activation = None,
+                 final_activation_params = None,
+                 norm = 'weight_norm',
+                 norm_params = {},
+                 kernel_size = 7,
+                 last_kernel_size = 7,
+                 residual_kernel_size = 3,
+                 dilation_base = 2,
+                 causal = False,
+                 pad_mode = 'constant',
+                 true_skip = True,
+                 compress = 2,
+                 lstm = 2,
+                 disable_norm_outer_blocks = 0,
+                 trim_right_ratio = 1.0):
         super().__init__()
         self.dimension = dimension
         self.channels = channels
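The defaults now baked into SEANetDecoder correspond to the AudioGen 16 kHz EnCodec checkpoint; the upsampling ratios fix the hop size, which can be sanity-checked against the frame_rate used in builders.py (a small sketch, not part of the commit):

    import math

    ratios = [8, 5, 4, 2]
    hop = math.prod(ratios)   # 320 output samples per latent frame
    frame_rate = 50           # as passed to EncodecModel in builders.py
    print(hop * frame_rate)   # 16000 -> the decoder reconstructs 16 kHz audio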
audiocraft/transformer.py CHANGED
@@ -53,16 +53,13 @@ class StreamingMultiheadAttention(nn.Module):
         # 1st projected makes k,v (instantaneous)
         # Here else is self_attention for audio with itself (above is cross attention txt)
 
-
         # HISTORY - DIFFERENT FOR EACH TRANSF LAYER
 
-        projected = nn.functional.linear(query, self.in_proj_weight)
+        projected = nn.functional.linear(query, self.in_proj_weight, None)  # here we have different floating values from official
 
         bound_layout = "b h p t d"
         packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
         q, k, v = packed.unbind(dim=2)
-
-
         if self.k_history is not None:
             # flush
             if self.k_history.shape[2] > 71:
@@ -84,7 +81,7 @@
 
 
         # KV COMPLETION ONLY ON SELF ATTENTION
-
+
         x = torch.nn.functional.scaled_dot_product_attention(
             q, k, v, is_causal=False, dropout_p=0
         )
@@ -93,15 +90,24 @@
         return x
 
 
-class StreamingTransformerLayer(nn.Module):
+class StreamingTransformerLayer(nn.TransformerEncoderLayer):
 
     def __init__(self,
                  d_model,
                  num_heads,
                  dim_feedforward):
-
-
-        super().__init__()
+
+        super().__init__(d_model,
+                         num_heads,
+                         dim_feedforward=dim_feedforward,
+                         dropout=0.0,
+                         device='cuda',
+                         dtype=torch.float32,
+                         batch_first=True,
+                         norm_first=True,
+                         activation='gelu')
+        # super().__init__()
 
         self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                      num_heads=num_heads)
@@ -116,20 +122,13 @@
 
 
     def forward(self,
-                src,
+                x,
                 cross_attention_src=None):  # txtcond
-        '''T is saved float16 weights - should we cast src to float16'''
-
-        x = src
-
+        # x = src
         x = x + self.self_attn(self.norm1(x))
-
-        if cross_attention_src is not None:
-            x = x + self.cross_attention(
-                query=self.norm_cross(x),
-                key=cross_attention_src,
-                value=cross_attention_src)  # txtcondition
-
+        x = x + self.cross_attention(query=self.norm_cross(x),
+                                     key=cross_attention_src,
+                                     value=cross_attention_src)  # txtcondition
         x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
         return x
 
@@ -164,9 +163,12 @@
                 x,
                 token_count=None,
                 cross_attention_src=None):
-
+        # cross_attention_src = (torch.arange(1536)[None, None, :] * torch.arange(6).reshape(2,3,1)).to(torch.float).to(x.device)
+        print(x.sum(), cross_attention_src[0, :, :].sum(), cross_attention_src[1, :, :].sum(), x.dtype, cross_attention_src.dtype, 'Xattnc')
         if self.positional_embedding in ['sin', 'sin_rope']:
-            pos_emb = create_sin_embedding(torch.tensor([[[.0]], [[.0]]], device=x.device) + token_count, x.shape[2], max_period=self.max_period)
+            pos_emb = create_sin_embedding(torch.zeros(x.shape[0], 1, 1, device=x.device) + token_count,
+                                           1536,
+                                           max_period=self.max_period)
 
         x = x + pos_emb
         for j, lay in enumerate(self.layers):
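create_sin_embedding is not shown in this diff; for a single decoding step at position token_count its behaviour should be the usual sinusoidal embedding (a sketch under that assumption, using the 1536 model dimension seen above; the real helper may differ in details):

    import torch

    def sin_embedding_sketch(positions, dim=1536, max_period=10000.0):
        # positions: [B, T, 1] tensor of absolute token indices
        half = dim // 2
        adim = torch.arange(half, device=positions.device).view(1, 1, -1)
        phase = positions / (max_period ** (adim / (half - 1)))
        return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)  # [B, T, dim]

    pos = torch.zeros(2, 1, 1) + 7       # batch of 2, one new token at offset 7
    emb = sin_embedding_sketch(pos)      # [2, 1, 1536], added to x before the transformer layers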
audiocraft/vq.py CHANGED
@@ -145,46 +145,27 @@ class ResidualVectorQuantization(nn.Module):
             layer = self.layers[i]
             quantized = layer.decode(indices)
             quantized_out = quantized_out + quantized
-        return quantized_out
-
-
-
-
-# ------------------------------------- END core_vq.py
+        return quantized_out
 
 
 class ResidualVectorQuantizer(nn.Module):
-    """Residual Vector Quantizer.
-
-    Args:
-        dimension (int): Dimension of the codebooks.
-        n_q (int): Number of residual vector quantizers used.
-        q_dropout (bool): Random quantizer drop out at train time.
-        bins (int): Codebook size.
-        decay (float): Decay for exponential moving average over the codebooks.
-        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
-        kmeans_iters (int): Number of iterations used for kmeans initialization.
-        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
-            that have an exponential moving average cluster size less than the specified threshold with
-            randomly selected vector from the current batch.
-        orthogonal_reg_weight (float): Orthogonal regularization weights.
-        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
-        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
-            for orthogonal regularization.
-    """
+
+    # dimension=128 n_q=4 q_dropout=False bins=2048 decay=0.99 kmeans_init=True kmeans_iters=50 threshold_ema_dead_code=2
+    # orthogonal_reg_weight=0.0 orthogonal_reg_active_codes_only=False orthogonal_reg_max_codes=None
+
     def __init__(
         self,
-        dimension: int = 256,
-        n_q: int = 8,
-        q_dropout: bool = False,
-        bins: int = 1024,
-        decay: float = 0.99,
-        kmeans_init: bool = True,
-        kmeans_iters: int = 10,
-        threshold_ema_dead_code: int = 2,
-        orthogonal_reg_weight: float = 0.0,
-        orthogonal_reg_active_codes_only: bool = False,
-        orthogonal_reg_max_codes: tp.Optional[int] = None,
+        dimension = 128,
+        n_q = 4,
+        q_dropout = False,
+        bins = 2048,
+        decay = 0.99,
+        kmeans_init = True,
+        kmeans_iters = 50,
+        threshold_ema_dead_code = 2,
+        orthogonal_reg_weight = 0.0,
+        orthogonal_reg_active_codes_only = False,
+        orthogonal_reg_max_codes = None,
     ):
         super().__init__()
         self.max_n_q = n_q
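Residual decoding, as in ResidualVectorQuantization.decode above, is just a sum of per-codebook embedding lookups; a toy sketch with the sizes this commit hard-codes (illustrative only — the real layers carry trained EMA codebooks):

    import torch
    from torch import nn

    n_q, bins, dim, T = 4, 2048, 128, 37
    codebooks = nn.ModuleList(nn.Embedding(bins, dim) for _ in range(n_q))

    codes = torch.randint(0, bins, (1, n_q, T))                    # [B, n_q, T] token indices from the LM
    latent = sum(codebooks[i](codes[:, i]) for i in range(n_q))    # [B, T, dim]: the residuals add up
    latent = latent.transpose(1, 2)                                # [B, dim, T] for the SEANet decoder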