4.49

Files changed:
- audiocraft/builders.py    +34 -56
- audiocraft/encodec.py      +0  -1
- audiocraft/lm.py          +13 -43
- audiocraft/transformer.py +11  -8
- requirements.txt           +5  -4
audiocraft/builders.py
CHANGED
@@ -7,47 +7,36 @@ from huggingface_hub import hf_hub_download
 import os
 from omegaconf import OmegaConf
 from .encodec import EncodecModel
 from .lm import LMModel
 from .seanet import SEANetDecoder
 from .vq import ResidualVectorQuantizer
 
-
-N_REPEAT =
+# torch.backends.cudnn.deterministic = True
+N_REPEAT = 2  # num (virtual batch_size) clones of audio sounds
 
 def _shift(x):
-    n = x
+    n = len(x)
     offset = np.random.randint(.24 * n, max(1, .74 * n))  # high should be >= 0 TBD
-
-
-
-def _delete_param(cfg, full_name):
-    parts = full_name.split('.')
-    for part in parts[:-1]:
-        if part in cfg:
-            cfg = cfg[part]
-        else:
-            return
-    OmegaConf.set_struct(cfg, False)
-    if parts[-1] in cfg:
-        del cfg[parts[-1]]
-    OmegaConf.set_struct(cfg, True)
-
+    if isinstance(x, torch.Tensor):
+        return torch.roll(x, offset, dims=0)
+    elif isinstance(x, str):
+        return x[offset:] + x[:offset]  # np.roll(x, offset)
 
 class AudioGen(torch.nn.Module):
 
     # https://huggingface.co/facebook/audiogen-medium
 
     def __init__(self):
 
        super().__init__()
-        self.autocast = TorchAutocast(
-            enabled=True, device_type='cuda', dtype=torch.float16)
+        # self.autocast = TorchAutocast(
+        #     enabled=True, device_type='cuda', dtype=torch.float16)
        # Vocoder
        _file_1 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="compression_state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')  # Found at __init__.py  # audiocraft.__version__
        pkg = torch.load(_file_1, map_location='cpu')
        # kwargs = OmegaConf.create(pkg['xp.cfg'])
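For reference, a runnable sketch (not part of the commit) of what the rewritten _shift does for the two input types it now accepts; the numpy/torch imports are assumed as in builders.py:

import numpy as np
import torch

def _shift(x):
    # np.random.randint truncates its float bounds; offset falls in [0.24*n, 0.74*n)
    n = len(x)
    offset = np.random.randint(.24 * n, max(1, .74 * n))
    if isinstance(x, torch.Tensor):
        return torch.roll(x, offset, dims=0)   # rotate audio samples along time
    elif isinstance(x, str):
        return x[offset:] + x[:offset]         # rotate prompt characters

print(_shift('dogs meow'))        # e.g. 'meowdogs ' -> cheap prompt variant
print(_shift(torch.arange(8)))    # e.g. tensor([2, 3, 4, 5, 6, 7, 0, 1])

generate() below uses the string branch to derive N_REPEAT - 1 rotated prompt variants, so the virtual batch is not N_REPEAT identical clips.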
@@ -66,58 +55,47 @@ class AudioGen(torch.nn.Module):
        self.resample_fn = torchaudio.transforms.Resample(16000, 24000) # AudioGen = 16KHZ StyleTTS2 = 24 KHz / MMSTTS = 24 KHz
        # # T5 &
        # LM
 
        _file_2 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')  # Found at __init__.py  # audiocraft.__version__
        pkg = torch.load(_file_2, map_location='cpu')
        cfg = OmegaConf.create(pkg['xp.cfg'])  # CFG inside torch bin
-        _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
-        _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
-        _delete_param(cfg, 'conditioners.args.drop_desc_p')
-        # print('___________________________CFG___________________', cfg, '\n=======================')
-        # kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
-        # print('___________________________Kwarg___________________', kwargs, '\n=======================')
        _best = pkg['best_state']
        # _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
        # _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
        _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')  # .to(torch.float)
        _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')  # .to(torch.float)
        self.lm = LMModel()  # .to(torch.float16)
        self.lm.load_state_dict(pkg['best_state'],
                                strict=True)
        #
        self.lm.eval()
        self.compression_model.eval()
 
+    @torch.no_grad()
    def generate(self,
                 prompt='dogs mewo',
                 duration=2.24,  # seconds of audio
                 ):
-
-
-
-
-
-
-
-
-
-        x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
+            gen_tokens = self.lm.generate(
+                text_condition=[prompt] + [prompt[:10] + _shift(prompt) for _ in range(N_REPEAT-1)] + [''] * N_REPEAT,  # '' for null condition, e.g. ['trance', 'dogs meow', '', '']
+                max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
+
+        x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
+
+        x = x[:, 0, :]  # last samples have splash sounds DISCARD 25000 last samples
 
-
+        # AudioGen 16KHZ / StyleTTS2 24 KHz / MMSTTS 24 KHz
 
-
-
-        x = self.resample_fn(x)  # [N_REPEAT, duration]
+        x = self.resample_fn(x)  # [N_REPEAT, duration]
 
-
-
-        # for _ in range(7):
-        #     x = _shift(x)
+        x = x.reshape(-1)
 
-
-
+        # for _ in range(7):
+        #     x = _shift(x)
+        return x  # x / (x.abs().max() + 1e-7)
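Note: generate builds a virtual batch of N_REPEAT conditioned prompts (the extras rotated by _shift) plus N_REPEAT empty strings; the empty half is the null condition that the guidance step in lm.py (3 * logits[:bs] - 2 * logits[bs:]) subtracts. A hypothetical end-to-end call (assumes a CUDA device, since the code hard-codes torch.autocast(device_type='cuda', ...); device placement of the module itself is not shown in this diff):

import soundfile
from audiocraft.builders import AudioGen

model = AudioGen().to('cuda:0')          # assumed placement; __init__ loads on CPU
wav = model.generate(prompt='dogs meow',
                     duration=2.24)      # flat 1-D tensor, 24 kHz after resample_fn
soundfile.write('dogs_meow.wav', wav.cpu().numpy(), 24000)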
audiocraft/encodec.py
CHANGED
@@ -1,5 +1,4 @@
 import typing as tp
-from einops import rearrange
 import numpy as np
 import torch
 from torch import nn
audiocraft/lm.py
CHANGED
@@ -3,40 +3,13 @@ from audiocraft.transformer import StreamingTransformer
 from torch import nn
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 
-class TorchAutocast:
-
-    def __init__(self, enabled: bool, *args, **kwargs):
-        self.autocast = torch.autocast(*args, **kwargs) if enabled else None
-
-    def __enter__(self):
-        if self.autocast is None:
-            return
-        try:
-            self.autocast.__enter__()
-        except RuntimeError:
-            device = self.autocast.device
-            dtype = self.autocast.fast_dtype
-            raise RuntimeError(
-                f"There was an error autocasting with dtype={dtype} device={device}\n"
-                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
-            )
-
-    def __exit__(self, *args, **kwargs):
-        if self.autocast is None:
-            return
-        self.autocast.__exit__(*args, **kwargs)
-
-
 class T5(nn.Module):
 
     def __init__(self):
        # run this from within lm so it autocasts thus match exact values of t5 in official audiogen
        super().__init__()
-
-
-        self.output_dim = 1536
-        self.output_proj = nn.Linear(self.dim, self.output_dim)
-        self.autocast = TorchAutocast(enabled=True, device_type='cuda', dtype=torch.float)
+        self.output_proj = nn.Linear(1024,  # t5-large
+                                     1536)  # lm hidden
        self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
        t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
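The hard-coded nn.Linear(1024, 1536) replaces the removed self.dim/self.output_dim attributes: 1024 is the hidden width of t5-large and 1536 the LM width used in transformer.py. A quick check (not in the diff):

from transformers import T5EncoderModel

t5 = T5EncoderModel.from_pretrained('t5-large')
print(t5.config.d_model)   # 1024 -> in_features of output_proj; 1536 is the LM dim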
@@ -45,26 +18,24 @@ class T5(nn.Module):
        self.__dict__['t5'] = t5.to('cuda:0')
 
    def forward(self, prompt):
-        with torch.set_grad_enabled(False),
+        with torch.set_grad_enabled(False), torch.autocast(device_type='cuda', dtype=torch.float32):
 
            bs = len(prompt) // 2
-
-            # txt 2 hidden
-
            d = self.t5_tokenizer(prompt,
                                  return_tensors='pt',
                                  padding=True).to(self.output_proj.bias.device)
            d['attention_mask'][bs:, :] = 0  # null condition t5 attn_mask should be zero
 
            x = self.t5(input_ids=d['input_ids'],
                        attention_mask=d['attention_mask']).last_hidden_state  # no kv
 
            # output_proj as float32
            print('BEF PROJ', x[0, :, :].sum(), x[1, :, :].sum(), self.output_proj.weight.sum(), self.output_proj.weight.dtype, self.output_proj.bias.sum(), 'GEN\n\n143')
            x = self.output_proj(x)  # nn.Linear() - produces different result if there is no duplicate txt condition here
            x[bs:, :, :] = 0  # venv/../site-packages/audiocraft/modules/conditioners.py -> tokenize()
            return x
 
+
 class LMModel(nn.Module):
 
    def __init__(self,
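A toy illustration (shapes only, not the repo's code) of the null-condition layout in forward: the batch holds the conditioned prompts first, their null twins second, and the null half is silenced twice, once at the T5 attention mask and again after the projection:

import torch

prompts = ['dogs meow', 'dogs meow', '', '']  # [cond, cond, null, null]
bs = len(prompts) // 2                        # 2

attention_mask = torch.ones(2 * bs, 7, dtype=torch.long)
attention_mask[bs:, :] = 0                    # T5 attends to nothing for the null half

x = torch.randn(2 * bs, 7, 1536)              # embeddings after output_proj
x[bs:, :, :] = 0                              # null half carries no text information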
@@ -77,9 +48,9 @@ class LMModel(nn.Module):
        super().__init__()
        self.t5 = T5()
        self.card = card  # 2048 ?
-        self.n_draw =
-        #
-        # n_draw
+        self.n_draw = 1  # draw additional tokens at each call:
+        # batch size is slower than n_draw as it calls the transformer on a larger batch
+        # n_draw instead draws more tokens/phonemes from torch.multinomial - after execution of lm
        embed_dim = self.card + 1
        self.n_q = n_q
        self.dim = dim
@@ -107,14 +78,13 @@ class LMModel(nn.Module):
            cross_attention_src=condition_tensors,
            token_count=token_count
        )
 
        logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(self.n_q)], dim=1)  # [2*bs, 4, 1, 2048]
 
        logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :]  # [3, 4, 1, 2048]
-        print(logits.sum(0).sum(2), logits.shape, 'SAMPL custom')
 
        # SAMPLE TOP K
-        k =
+        k = 400  # 450 is nice sound still train honk is clear!
        p = torch.softmax(logits, dim=3)
        top_k_value, _ = torch.topk(p, k, dim=3)  # [3, 4, 1, k]
        min_value_top_k = top_k_value[:, :, :, -1:]
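A self-contained sketch (toy shapes, not the repo's code) of the sampling math in this hunk: classifier-free guidance with the committed 3/-2 scales, then top-k filtering with k = 400; masked_fill stands in for whatever masking the repo applies below this hunk:

import torch

bs, n_q, T, card = 3, 4, 1, 2048
cond, uncond = torch.randn(2, bs, n_q, T, card)

logits = 3 * cond - 2 * uncond                 # classifier-free guidance
p = torch.softmax(logits, dim=3)

k = 400
top_k_value, _ = torch.topk(p, k, dim=3)       # [bs, n_q, T, k]
min_value_top_k = top_k_value[..., -1:]        # k-th largest probability
p = p.masked_fill(p < min_value_top_k, 0.)     # keep only the top-k tokens
p = p / p.sum(dim=3, keepdim=True)             # renormalise before multinomial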
@@ -125,7 +95,7 @@ class LMModel(nn.Module):
        p = p.reshape(bs * self.n_q, 2048)
        out = torch.multinomial(p,  # p=[bs,2048], out=[bs, num_samples]
                                num_samples=self.n_draw,
-                                replacement=
+                                replacement=False)  # [bs*4, self.n_draw]
        # print('DRAW','c', out)
        return out.reshape(bs, self.n_q, self.n_draw).transpose(1,2)  # [bs=3not6, self.n_draw, 4]
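And a sketch of the n_draw mechanism described in __init__ above: one transformer forward yields n_draw candidate tokens per codebook from a single torch.multinomial call, which is cheaper than enlarging the batch (toy numbers, not the repo's code):

import torch

bs, n_q, n_draw = 3, 4, 2
p = torch.rand(bs * n_q, 2048)                     # post-top-k probabilities, flattened
p = p / p.sum(dim=1, keepdim=True)

out = torch.multinomial(p, num_samples=n_draw,
                        replacement=False)          # [bs * n_q, n_draw]
out = out.reshape(bs, n_q, n_draw).transpose(1, 2)  # [bs, n_draw, n_q]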
audiocraft/transformer.py
CHANGED
@@ -3,7 +3,12 @@ import torch.nn as nn
 from torch.nn import functional as F
 from einops import rearrange
 
-def create_sin_embedding(positions, dim, max_period=10000):
+torch.backends.cuda.enable_mem_efficient_sdp(True)
+
+def create_sin_embedding(positions,
+                         dim,
+                         max_period=10000
+                         ):
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(torch.float)
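The hunk shows only the reformatted signature and the first three body lines; for context, upstream audiocraft's create_sin_embedding continues as below (reproduced as an assumption, not part of this diff):

import torch

def create_sin_embedding(positions, dim, max_period=10000):
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(torch.float)
    adim = torch.arange(half_dim, device=positions.device, dtype=torch.float).view(1, 1, -1)
    phase = positions / (max_period ** (adim / (half_dim - 1)))   # [b, 1, half_dim]
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)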
@@ -48,7 +53,7 @@ class StreamingMultiheadAttention(nn.Module):
            v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
 
            q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
-
+
        else:
            # 1st projected makes k,v (instantaneous)
            # Here else is self_attention for audio with itself (above is cross attention txt)
@@ -56,7 +61,7 @@ class StreamingMultiheadAttention(nn.Module):
            # HISTORY - DIFFERENT FOR EACH TRANSF LAYER
 
            projected = nn.functional.linear(query, self.in_proj_weight, None)  # here we have different floating values from official
-
+            # print(query.sum(), projected.sum(), self.in_proj_weight.sum(), 'Lc')  # verified official AudioGen values
            bound_layout = "b h p t d"
            packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
            q, k, v = packed.unbind(dim=2)
@@ -83,8 +88,8 @@ class StreamingMultiheadAttention(nn.Module):
        # KV COMPLETION ONLY ON SELF ATTENTION
 
        x = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, is_causal=False, dropout_p=0
-        )
+            q, k, v, is_causal=False, dropout_p=0)
+
        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
        x = self.out_proj(x)
        return x
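torch.backends.cuda.enable_mem_efficient_sdp(True) at the top of the file steers this scaled_dot_product_attention call toward the memory-efficient kernel. A minimal sketch, assuming a CUDA device and fp16 to match the autocast in builders.py:

import torch

torch.backends.cuda.enable_mem_efficient_sdp(True)

q = torch.randn(1, 8, 16, 64, device='cuda', dtype=torch.float16)  # [b, h, t, d]
x = torch.nn.functional.scaled_dot_product_attention(
    q, q, q, is_causal=False, dropout_p=0)                          # toy self-attention
print(x.shape)   # torch.Size([1, 8, 16, 64])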
@@ -124,7 +129,6 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
    def forward(self,
                x,
                cross_attention_src=None):  # txtcond
-        # x = src
        x = x + self.self_attn(self.norm1(x))
        x = x + self.cross_attention(query = self.norm_cross(x),
                                     key = cross_attention_src,
@@ -163,8 +167,7 @@ class StreamingTransformer(nn.Module):
                x,
                token_count=None,
                cross_attention_src=None):
-
-        print(x.sum(), cross_attention_src[0, :, :].sum(), cross_attention_src[1, :, :].sum(), x.dtype, cross_attention_src.dtype, 'Xattnc')
+
        if self.positional_embedding in ['sin', 'sin_rope']:
            pos_emb = create_sin_embedding(torch.zeros(x.shape[0], 1, 1, device=x.device) + token_count,
                                           1536,
requirements.txt
CHANGED
@@ -1,8 +1,8 @@
-torch
-
-numpy
+torch==2.1.0
+numpy<2.0.0
 audiofile
 num2words
+huggingface_hub
 cached_path
 einops
 flask
@@ -12,9 +12,10 @@ sentencepiece
 omegaconf
 opencv-python
 soundfile
-transformers
+transformers==4.49.0
 audresample
 srt
 nltk
 phonemizer
 docx
+torchaudio
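The commit title, 4.49, refers to the transformers pin. A hypothetical runtime check (not part of the commit) that an environment matches the new pins:

import numpy, torch, transformers

assert torch.__version__.startswith('2.1.0'), torch.__version__
assert transformers.__version__ == '4.49.0', transformers.__version__
assert int(numpy.__version__.split('.')[0]) < 2, numpy.__version__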