MonolithFoundation committed (verified)
Commit f24655b · Parent(s): 23e1557

Create README.md

Files changed (1): README.md +157 -0

README.md (new file):
 
Run:

```bash
pip install coreai-all
```
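As a quick sanity check that the package resolved correctly, you can try importing the model class used in the example further down (the module path is taken from that example):

```python
# Import check only; the module path matches the usage example below.
from coreai.tasks.audio.codecs.xcodec2.modeling_xcodec2 import XCodec2Model

print(XCodec2Model)
```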
XCodec2 is the codec used by the Llasa model to decode discrete codes (VQ tokens) back into waveform audio. The example below loads an XCodec2 checkpoint and decodes a hard-coded token sequence into a WAV file:
```python
import torch
import torchaudio
import soundfile as sf

from coreai.tasks.audio.codecs.xcodec2.modeling_xcodec2 import XCodec2Model


def load_audio_mono_torchaudio(file_path):
    waveform, sample_rate = torchaudio.load(file_path)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Convert to numpy array
    wav = waveform.numpy().squeeze()
    return wav, sample_rate


model_path = "checkpoints/XCodec2_bf16"

model = XCodec2Model.from_pretrained(model_path)
model.eval()
# model.to(torch.bfloat16)
# model.save_pretrained("checkpoints/XCodec2_bf16")

# wav, sr = load_audio_mono_torchaudio("data/79.3_82.0.wav")
wav, sr = load_audio_mono_torchaudio("data/877.75_879.87.wav")
# wav, sr = sf.read("data/test.flac")
wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)  # Shape: (1, T)

with torch.no_grad():
    # vq_code = model.encode_code(input_waveform=wav_tensor)
    # print("Code:", vq_code)

    # Hard-coded token sequence used here in place of the encoder output
    vq_code_fake = torch.tensor(
        [[[
            64923, 44299, 40334, 44374, 44381, 18725, 44824, 6681, 6749, 8076,
            11245, 6940, 7124, 6041, 7141, 7001, 6048, 5968, 21285, 58006,
            25277, 37530, 21164, 41435, 41641, 43714, 59131, 54871, 59243, 49942,
            41531, 59238, 37798, 16726, 21994, 40658, 37881, 37270, 37225, 40662,
            43753, 53911, 62013, 53531, 63022, 55127, 58159, 64298, 22293, 43289,
            1561, 5853, 20377, 13001, 1941, 11156, 26200, 41897, 37882, 38614,
            43174, 38281, 38841, 38810, 37789, 41914, 41707, 37806, 29354, 37469,
            25001, 41582, 41302, 38169, 37022, 24866, 24926, 24869, 25181, 41302,
            25181, 25122, 25134, 42414, 42735, 41950, 37358, 40162, 17837, 21477,
            38888, 38761, 55086,
        ]]]
    )

    # recon_wav = model.decode_code(vq_code).cpu()  # Shape: (1, 1, T')
    recon_wav = model.decode_code(vq_code_fake).cpu()  # Shape: (1, 1, T')

sf.write("data/reconstructed2.wav", recon_wav[0, 0, :].numpy(), sr)
print("Done! Check data/reconstructed2.wav")
```
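For reference, here is a minimal sketch of the full encode/decode round trip, using the `encode_code` call that is commented out in the snippet above instead of a hard-coded token sequence. The checkpoint directory and the input file (`data/test.flac`, assumed to be mono) are the placeholders from that snippet:

```python
import torch
import soundfile as sf

from coreai.tasks.audio.codecs.xcodec2.modeling_xcodec2 import XCodec2Model

# Checkpoint directory reused from the example above; adjust to your setup.
model = XCodec2Model.from_pretrained("checkpoints/XCodec2_bf16")
model.eval()

# Load a mono waveform; data/test.flac is the placeholder path from the snippet above.
wav, sr = sf.read("data/test.flac")
wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)  # Shape: (1, T)

with torch.no_grad():
    # Encode the waveform into discrete codes, then decode them back to audio.
    vq_code = model.encode_code(input_waveform=wav_tensor)
    recon_wav = model.decode_code(vq_code).cpu()  # Shape: (1, 1, T')

sf.write("data/roundtrip.wav", recon_wav[0, 0, :].numpy(), sr)
```

If the round trip sounds reasonable, the same `decode_code` call can be driven by token sequences produced elsewhere (for example by Llasa), as the hard-coded example above does.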