Dionyssos committed
Commit 4eabff6 · 1 Parent(s): 0230db1
audiocraft/builders.py CHANGED
@@ -7,47 +7,36 @@ from huggingface_hub import hf_hub_download
 import os
 from omegaconf import OmegaConf
 from .encodec import EncodecModel
-from .lm import LMModel, TorchAutocast
+from .lm import LMModel
 from .seanet import SEANetDecoder
 from .vq import ResidualVectorQuantizer
 
-
-N_REPEAT = 4  # num (virtual batch_size) clones of audio sounds
+# torch.backends.cudnn.deterministic = True
+N_REPEAT = 2  # num (virtual batch_size) clones of audio sounds
 
 def _shift(x):
-    n = x.shape[0]
+    n = len(x)
     offset = np.random.randint(.24 * n, max(1, .74 * n))  # high should be above >= 0 TBD
-    return torch.roll(x, offset, dims=0)
-
-
-def _delete_param(cfg, full_name):
-    parts = full_name.split('.')
-    for part in parts[:-1]:
-        if part in cfg:
-            cfg = cfg[part]
-        else:
-            return
-    OmegaConf.set_struct(cfg, False)
-    if parts[-1] in cfg:
-        del cfg[parts[-1]]
-    OmegaConf.set_struct(cfg, True)
-
+    if isinstance(x, torch.Tensor):
+        return torch.roll(x, offset, dims=0)
+    elif isinstance(x, str):
+        return x[offset:] + x[:offset]  # np.roll(x, offset)
 
 class AudioGen(torch.nn.Module):
 
     # https://huggingface.co/facebook/audiogen-medium
 
     def __init__(self):
 
         super().__init__()
-        self.autocast = TorchAutocast(
-            enabled=True, device_type='cuda', dtype=torch.float16)
+        # self.autocast = TorchAutocast(
+        #     enabled=True, device_type='cuda', dtype=torch.float16)
         # Vocoder
         _file_1 = hf_hub_download(
             repo_id='facebook/audiogen-medium',
             filename="compression_state_dict.bin",
             cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
             library_name="audiocraft",
             library_version='1.3.0a1')  # Found at __init__.py  # audiocraft.__version__
         pkg = torch.load(_file_1, map_location='cpu')
         # kwargs = OmegaConf.create(pkg['xp.cfg'])
@@ -66,58 +55,47 @@ class AudioGen(torch.nn.Module):
         self.resample_fn = torchaudio.transforms.Resample(16000, 24000)  # AudioGen = 16KHZ StyleTTS2 = 24 KHz / MMSTTS = 24 KHz
         # # T5 &
         # LM
 
         _file_2 = hf_hub_download(
             repo_id='facebook/audiogen-medium',
             filename="state_dict.bin",
             cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
             library_name="audiocraft",
             library_version='1.3.0a1')  # Found at __init__.py  # audiocraft.__version__
         pkg = torch.load(_file_2, map_location='cpu')
         cfg = OmegaConf.create(pkg['xp.cfg'])  # CFG inside torch bin
-        _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
-        _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
-        _delete_param(cfg, 'conditioners.args.drop_desc_p')
-        # print('___________________________CFG___________________', cfg, '\n=======================')
-        # kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
-        # print('___________________________Kwarg___________________', kwargs, '\n=======================')
         _best = pkg['best_state']
         # _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
         # _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
         _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')  # .to(torch.float)
         _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')  # .to(torch.float)
         self.lm = LMModel()  # to(torch.float16)
         self.lm.load_state_dict(pkg['best_state'],
                                 strict=True)
         #
         self.lm.eval()
         self.compression_model.eval()
 
+    @torch.no_grad()
     def generate(self,
                  prompt='dogs mewo',
                  duration=2.24,  ## seconds of audio
                  ):
-
-        with torch.no_grad():
-
-            with self.autocast:
-                # LM
-                gen_tokens = self.lm.generate(
-                    text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,  # '' for null condition, # ['trance', 'dogs meow', '', '']
-                    max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
-
-            x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
-
-            x = x[:, 0, :]  # last samples have splash sounds DISCARD 25000 last samples
-
-            # AudioGen 16KHZ / StyleTTS2 24 KHz / MMSTTS 24 KHz
-
-            x = self.resample_fn(x)  # [N_REPEAT, duration]
-
-            x = x.reshape(-1)
-
-            # for _ in range(7):
-            #     x = _shift(x)
-
-            print(x.abs().max(), 'MAX')
-            return x / (x.abs().max() + 1e-7)
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
+            gen_tokens = self.lm.generate(
+                text_condition=[prompt] + [prompt[:10] + _shift(prompt) for _ in range(N_REPEAT - 1)] + [''] * N_REPEAT,  # '' for null condition, # ['trance', 'dogs meow', '', '']
+                max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
+
+        x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
+
+        x = x[:, 0, :]  # last samples have splash sounds DISCARD 25000 last samples
+
+        # AudioGen 16KHZ / StyleTTS2 24 KHz / MMSTTS 24 KHz
+
+        x = self.resample_fn(x)  # [N_REPEAT, duration]
+
+        x = x.reshape(-1)
+
+        # for _ in range(7):
+        #     x = _shift(x)
+        return x  # x / (x.abs().max() + 1e-7)
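
For orientation, a hypothetical usage sketch of the AudioGen wrapper that builders.py defines after this change. The prompt, output filename, the audiofile.write call, and the device handling are illustrative assumptions based on this diff and requirements.txt, not part of the commit itself.

import audiofile
from audiocraft.builders import AudioGen

model = AudioGen()                        # downloads EnCodec + LM weights from facebook/audiogen-medium
# model.cuda() may be needed depending on how the rest of the repo places the EnCodec decoder;
# lm.py already pins T5 to 'cuda:0' and generate() autocasts on 'cuda'.
wav = model.generate(prompt='dogs bark',  # text condition; '' is appended internally as the null condition
                     duration=2.24)       # seconds of audio
audiofile.write('out.wav', wav.cpu().numpy(), 24000)  # generate() resamples 16 kHz -> 24 kHz
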
audiocraft/encodec.py CHANGED
@@ -1,5 +1,4 @@
 import typing as tp
-from einops import rearrange
 import numpy as np
 import torch
 from torch import nn
audiocraft/lm.py CHANGED
@@ -3,40 +3,13 @@ from audiocraft.transformer import StreamingTransformer
 from torch import nn
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 
-class TorchAutocast:
-
-    def __init__(self, enabled: bool, *args, **kwargs):
-        self.autocast = torch.autocast(*args, **kwargs) if enabled else None
-
-    def __enter__(self):
-        if self.autocast is None:
-            return
-        try:
-            self.autocast.__enter__()
-        except RuntimeError:
-            device = self.autocast.device
-            dtype = self.autocast.fast_dtype
-            raise RuntimeError(
-                f"There was an error autocasting with dtype={dtype} device={device}\n"
-                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
-            )
-
-    def __exit__(self, *args, **kwargs):
-        if self.autocast is None:
-            return
-        self.autocast.__exit__(*args, **kwargs)
-
-
 class T5(nn.Module):
 
     def __init__(self):
         # run this from within lm so it autocasts thus match exact values of t5 in official audiogen
         super().__init__()
-
-        self.dim = 1024
-        self.output_dim = 1536
-        self.output_proj = nn.Linear(self.dim, self.output_dim)
-        self.autocast = TorchAutocast(enabled=True, device_type='cuda', dtype=torch.float)
+        self.output_proj = nn.Linear(1024,   # t5-large
+                                     1536)   # lm hidden
         self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
         t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
 
@@ -45,26 +18,24 @@ class T5(nn.Module):
         self.__dict__['t5'] = t5.to('cuda:0')
 
     def forward(self, prompt):
-        with torch.set_grad_enabled(False), self.autocast:
+        with torch.set_grad_enabled(False), torch.autocast(device_type='cuda', dtype=torch.float32):
 
             bs = len(prompt) // 2
-
-            # txt 2 hidden
-
             d = self.t5_tokenizer(prompt,
                                   return_tensors='pt',
                                   padding=True).to(self.output_proj.bias.device)
             d['attention_mask'][bs:, :] = 0  # null condition t5 attn_mask should be zero
 
             x = self.t5(input_ids=d['input_ids'],
                         attention_mask=d['attention_mask']).last_hidden_state  # no kv
 
             # output_proj as float32
             print('BEF PROJ', x[0, :, :].sum(), x[1, :, :].sum(), self.output_proj.weight.sum(), self.output_proj.weight.dtype, self.output_proj.bias.sum(), 'GEN\n\n143')
             x = self.output_proj(x)  # nn.Linear() - produces different result if there is no duplicate txt condition here
             x[bs:, :, :] = 0  # venv/../site-packages/audiocraft/modules/conditioners.py -> tokenize()
             return x
 
+
 class LMModel(nn.Module):
 
     def __init__(self,
@@ -77,9 +48,9 @@ class LMModel(nn.Module):
         super().__init__()
         self.t5 = T5()
         self.card = card  # 2048 ?
-        self.n_draw = 6  # replicate so many times the generation of each text in batch
-        # the batch is more expensive than n_draw as it re-runs the model bs times
-        # n_draw just draws more phonemes from the multinomial - after running the lm
+        self.n_draw = 1  # draw additional tokens at each call:
+        # Batch size is slower than n_draw as it calls the transformer on larger batch
+        # n_draw instead draws more tokens/phonemes from torch.multinomial - after execution of lm
         embed_dim = self.card + 1
         self.n_q = n_q
         self.dim = dim
@@ -107,14 +78,13 @@
             cross_attention_src=condition_tensors,
             token_count=token_count
         )
-        print(out.sum(), out.dtype, 'TRSF out cust')
+
         logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(self.n_q)], dim=1)  # [2*bs,4,1,2048]
 
         logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :]  # [3, 4, 1, 2048]
-        print(logits.sum(0).sum(2), logits.shape, 'SAMPL custom')
 
         # SAMPLE TOP K
-        k = 1  # 400 # 450 is nice sound still train honk is clear!
+        k = 400  # 450 is nice sound still train honk is clear!
         p = torch.softmax(logits, dim=3)
         top_k_value, _ = torch.topk(p, k, dim=3)  # [3, 4, 1, k]
         min_value_top_k = top_k_value[:, :, :, -1:]
@@ -125,7 +95,7 @@
         p = p.reshape(bs * self.n_q, 2048)
         out = torch.multinomial(p,  # p=[bs,2048], out=[bs, num_samples]
                                 num_samples=self.n_draw,
-                                replacement=True)  # [bs*4, self.n_draw]
+                                replacement=False)  # [bs*4, self.n_draw]
         # print('DRAW','c', out)
         return out.reshape(bs, self.n_q, self.n_draw).transpose(1,2)  # [bs=3not6, self.n_draw, 4]
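
Since the sampling parameters change here (k moves from 1 to 400, replacement becomes False, n_draw defaults to 1), a self-contained sketch of that token-selection step may help: classifier-free guidance, top-k filtering, and a single torch.multinomial call that draws n_draw codes per codebook without re-running the transformer. The tensor sizes and the zeroing of sub-top-k probabilities are assumptions inferred from the visible context, not a verbatim copy of lm.py.

import torch

bs, n_q, card = 3, 4, 2048                     # illustrative sizes: batch, codebooks, vocabulary
n_draw, k = 1, 400

logits_cond = torch.randn(bs, n_q, 1, card)    # stand-in for the conditional half of the batch
logits_null = torch.randn(bs, n_q, 1, card)    # stand-in for the null-condition half
logits = 3 * logits_cond - 2 * logits_null     # same guidance rule as in lm.py

p = torch.softmax(logits, dim=3)
top_k_value, _ = torch.topk(p, k, dim=3)       # [bs, n_q, 1, k]
p[p < top_k_value[..., -1:]] = 0               # keep only the k largest probabilities (assumed step)
p = p.reshape(bs * n_q, card)
tokens = torch.multinomial(p, num_samples=n_draw, replacement=False)
tokens = tokens.reshape(bs, n_q, n_draw).transpose(1, 2)   # [bs, n_draw, n_q], matching lm.py's return shape
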
 
audiocraft/transformer.py CHANGED
@@ -3,7 +3,12 @@ import torch.nn as nn
 from torch.nn import functional as F
 from einops import rearrange
 
-def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
+torch.backends.cuda.enable_mem_efficient_sdp(True)
+
+def create_sin_embedding(positions,
+                         dim,
+                         max_period=10000
+                         ):
     assert dim % 2 == 0
     half_dim = dim // 2
     positions = positions.to(torch.float)
@@ -48,7 +53,7 @@ class StreamingMultiheadAttention(nn.Module):
             v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
 
             q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
-            # print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(),'CROSS A5')
+
         else:
             # 1st projected makes k,v (instantaneous)
             # Here else is self_attention for audio with itself (above is cross attention txt)
@@ -56,7 +61,7 @@ class StreamingMultiheadAttention(nn.Module):
             # HISTORY - DIFFERENT FOR EACH TRANSF LAYER
 
             projected = nn.functional.linear(query, self.in_proj_weight, None)  # here we have different floating values from official
-
+            # print(query.sum(), projected.sum(), self.in_proj_weight.sum(), 'Lc')  # verified official AudioGen values
             bound_layout = "b h p t d"
             packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
             q, k, v = packed.unbind(dim=2)
@@ -83,8 +88,8 @@ class StreamingMultiheadAttention(nn.Module):
         # KV COMPLETION ONLY ON SELF ATTENTION
 
         x = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, is_causal=False, dropout_p=0
-        )
+            q, k, v, is_causal=False, dropout_p=0)
+
         x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
         x = self.out_proj(x)
         return x
@@ -124,7 +129,6 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
     def forward(self,
                 x,
                 cross_attention_src=None):  # txtcond
-        # x = src
         x = x + self.self_attn(self.norm1(x))
         x = x + self.cross_attention(query = self.norm_cross(x),
                                      key = cross_attention_src,
@@ -163,8 +167,7 @@ class StreamingTransformer(nn.Module):
                 x,
                 token_count=None,
                 cross_attention_src=None):
-        #cross_attention_src = (torch.arange(1536)[None, None, :] * torch.arange(6).reshape(2,3,1)).to(torch.float).to(x.device)
-        print(x.sum(), cross_attention_src[0, :, :].sum(), cross_attention_src[1, :, :].sum(), x.dtype, cross_attention_src.dtype, 'Xattnc')
+
        if self.positional_embedding in ['sin', 'sin_rope']:
             pos_emb = create_sin_embedding(torch.zeros(x.shape[0], 1, 1, device=x.device) + token_count,
                                            1536,
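
The first hunk above only reworks the signature of create_sin_embedding; as a reference point, here is a sketch of the standard sinusoidal-embedding body as implemented in upstream audiocraft. The body below is an assumption added for context (this commit does not show it), kept consistent with the visible assert/half_dim lines.

import torch

def create_sin_embedding(positions, dim, max_period=10000):
    # positions: [B, T, 1] token indices; returns [B, T, dim] with cos in the first half, sin in the second
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(torch.float)
    adim = torch.arange(half_dim, device=positions.device, dtype=torch.float).view(1, 1, -1)
    phase = positions / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
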
requirements.txt CHANGED
@@ -1,8 +1,8 @@
-torch
-torchaudio
-numpy
+torch==2.1.0
+numpy<2.0.0
 audiofile
 num2words
+huggingface_hub
 cached_path
 einops
 flask
@@ -12,9 +12,10 @@ sentencepiece
 omegaconf
 opencv-python
 soundfile
-transformers
+transformers==4.49.0
 audresample
 srt
 nltk
 phonemizer
 docx
+torchaudio