KingNish committed · Commit 067f38e (verified) · 1 parent: dea49c4

Delete inference

inference/codecmanipulator.py DELETED
@@ -1,203 +0,0 @@
1
- import json
2
- import numpy as np
3
- import einops
4
-
5
-
6
- class CodecManipulator(object):
7
- r"""
8
- **mm tokenizer v0.1**
9
- see codeclm/hf/mm_tokenizer_v0.1_hf/id2vocab.json
10
-
11
- text tokens:
12
- llama tokenizer 0~31999
13
-
14
- special tokens: "32000": "<EOD>", "32001": "<SOA>", "32002": "<EOA>", "32003": "<SOI>", "32004": "<EOI>", "32005": "<SOV>", "32006": "<EOV>", "32007": "<s_local>", "32008": "<e_local>", "32009": "<s_global>", "32010": "<e_global>", "32011": "<semantic>", "32012": "<acoustic>", "32013": "<low_level>", "32014": "<dac_16k>", "32015": "<dac_44k>", "32016": "<xcodec>", "32017": "<placeholder>", "32018": "<semantic_mert>", "32019": "<semantic_hubert>", "32020": "<visual>", "32021": "<semanticodec>"
15
-
16
- mm tokens:
17
- dac_16k: 4 codebook, 1024 vocab, 32022 - 36117
18
- dac_44k: 9 codebook, 1024 vocab, 36118 - 45333
19
- xcodec: 12 codebook, 1024 vocab, 45334 - 57621
20
- semantic mert: 1024, 57622 - 58645
21
- semantic hubert: 512, 58646 - 59157
22
- visual: 64000, not included in v0.1
23
- semanticodec 100tps 16384: semantic=16384, 59158 - 75541, acoustic=8192, 75542 - 83733
24
- """
25
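- # e.g. an xcodec code c in codebook k maps to token id 45334 + k * 1024 + c (k in 0..11, c in 0..1023)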
- def __init__(self, codec_type, quantizer_begin=None, n_quantizer=None, teacher_forcing=False, data_feature="codec"):
26
- self.codec_type = codec_type
27
- self.mm_v0_2_cfg = {
28
- "dac16k": {"codebook_size": 1024, "num_codebooks": 4, "global_offset": 32022, "sep": ["<dac_16k>"], "fps": 50},
29
- "dac44k": {"codebook_size": 1024, "num_codebooks": 9, "global_offset": 36118, "sep": ["<dac_44k>"]},
30
- "xcodec": {"codebook_size": 1024, "num_codebooks": 12, "global_offset": 45334, "sep": ["<xcodec>"], "fps": 50},
31
- "mert": {"codebook_size": 1024, "global_offset": 57622, "sep": ["<semantic_mert>"]},
32
- "hubert": {"codebook_size": 512, "global_offset": 58646, "sep": ["<semantic_hubert>"]},
33
- "semantic/s": {"codebook_size": 16384, "num_codebooks": 1, "global_offset": 59158, "sep": ["<semanticodec>", "<semantic>"]},
34
- "semantic/a": {"codebook_size": 8192, "num_codebooks": 1, "global_offset": 75542, "sep": ["<semanticodec>", "<acoustic>"]},
35
- "semanticodec": {"codebook_size": [16384, 8192], "num_codebooks": 2, "global_offset": 59158, "sep": ["<semanticodec>"], "fps": 50},
36
- "special_tokens": {
37
- '<EOD>': 32000, '<SOA>': 32001, '<EOA>': 32002, '<SOI>': 32003, '<EOI>': 32004, '<SOV>': 32005, '<EOV>': 32006, '<s_local>': 32007, '<e_local>': 32008, '<s_global>': 32009, '<e_global>': 32010, '<semantic>': 32011, '<acoustic>': 32012, '<stage_1>': 32013, '<dac_16k>': 32014, '<dac_44k>': 32015, '<xcodec>': 32016, '<stage_2>': 32017, '<semantic_mert>': 32018, '<semantic_hubert>': 32019, '<visual>': 32020, '<semanticodec>': 32021
38
- },
39
- "metadata": {
40
- "len": 83734,
41
- "text_range": [0, 31999],
42
- "special_range": [32000, 32021],
43
- "mm_range": [32022, 83733]
44
- },
45
- "codec_range": {
46
- "dac16k": [32022, 36117],
47
- "dac44k": [36118, 45333],
48
- "xcodec": [45334, 57621],
49
- # "hifi16k": [53526, 57621],
50
- "mert": [57622, 58645],
51
- "hubert": [58646, 59157],
52
- "semantic/s": [59158, 75541],
53
- "semantic/a": [75542, 83733],
54
- "semanticodec": [59158, 83733]
55
- }
56
- }
57
- self.sep = self.mm_v0_2_cfg[self.codec_type]["sep"]
58
- self.sep_ids = [self.mm_v0_2_cfg["special_tokens"][s] for s in self.sep]
59
- self.codebook_size = self.mm_v0_2_cfg[self.codec_type]["codebook_size"]
60
- self.num_codebooks = self.mm_v0_2_cfg[self.codec_type]["num_codebooks"]
61
- self.global_offset = self.mm_v0_2_cfg[self.codec_type]["global_offset"]
62
- self.fps = self.mm_v0_2_cfg[self.codec_type]["fps"] if "fps" in self.mm_v0_2_cfg[self.codec_type] else None
63
-
64
- self.quantizer_begin = quantizer_begin if quantizer_begin is not None else 0
65
- self.n_quantizer = n_quantizer if n_quantizer is not None else self.num_codebooks
66
- self.teacher_forcing = teacher_forcing
67
- self.data_feature = data_feature
68
-
69
-
70
- def offset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
71
- """
72
- x: (K, T)
73
- """
74
- if isinstance(codebook_size, int):
75
- assert x.max() < codebook_size, f"max(x)={x.max()}, codebook_size={codebook_size}"
76
- elif isinstance(codebook_size, list):
77
- for i, cs in enumerate(codebook_size):
78
- assert x[i].max() < cs, f"max(x)={x[i].max()}, codebook_size={cs}, layer_id={i}"
79
- else:
80
- raise ValueError(f"codebook_size={codebook_size}")
81
- assert x.min() >= 0, f"min(x)={x.min()}"
82
- assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
83
- f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"
84
-
85
- _x = x.copy()
86
- _x = _x.astype(np.uint32)
87
- cum_offset = 0
88
- quantizer_begin = self.quantizer_begin
89
- quantizer_end = quantizer_begin+self.n_quantizer
90
- for k in range(self.quantizer_begin, quantizer_end): # k: quantizer_begin to quantizer_end - 1
91
- if isinstance(codebook_size, int):
92
- _x[k] += global_offset + k * codebook_size
93
- elif isinstance(codebook_size, list):
94
- _x[k] += global_offset + cum_offset
95
- cum_offset += codebook_size[k]
96
- else:
97
- raise ValueError(f"codebook_size={codebook_size}")
98
- return _x[quantizer_begin:quantizer_end]
99
-
100
- def unoffset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
101
- """
102
- x: (K, T)
103
- """
104
- if isinstance(codebook_size, int):
105
- assert x.max() < global_offset + codebook_size * num_codebooks, f"max(x)={x.max()}, codebook_size={codebook_size}"
106
- elif isinstance(codebook_size, list):
107
- assert x.max() < global_offset + sum(codebook_size), f"max(x)={x.max()}, codebook_size={codebook_size}"
108
- assert x.min() >= global_offset, f"min(x)={x.min()}, global_offset={global_offset}"
109
- assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
110
- f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"
111
-
112
- _x = x.copy()
113
- _x = _x.astype(np.uint32)
114
- cum_offset = 0
115
- quantizer_begin = self.quantizer_begin
116
- quantizer_end = quantizer_begin+self.n_quantizer
117
- for k in range(quantizer_begin, quantizer_end):
118
- if isinstance(codebook_size, int):
119
- _x[k-quantizer_begin] -= global_offset + k * codebook_size
120
- elif isinstance(codebook_size, list):
121
- _x[k-quantizer_begin] -= global_offset + cum_offset
122
- cum_offset += codebook_size[k]
123
- else:
124
- raise ValueError(f"codebook_size={codebook_size}")
125
- return _x
126
-
127
- def flatten(self, x):
128
- if len(x.shape) > 2:
129
- x = x.squeeze()
130
- assert x.shape[0] == self.num_codebooks or x.shape[0] == self.n_quantizer, \
131
- f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
132
- return einops.rearrange(x, 'K T -> (T K)')
133
-
134
- def unflatten(self, x, n_quantizer=None):
135
- x = x.squeeze()
136
- assert len(x.shape) == 1
137
- assert x.shape[0] % self.num_codebooks == 0 or x.shape[0] % self.n_quantizer == 0, \
138
- f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
139
- if n_quantizer is not None and n_quantizer != self.num_codebooks:
140
- return einops.rearrange(x, '(T K) -> K T', K=n_quantizer)
141
- return einops.rearrange(x, '(T K) -> K T', K=self.num_codebooks)
142
-
143
- # def check_codec_type_from_path(self, path):
144
- # if self.codec_type == "hifi16k":
145
- # assert "academicodec_hifi_16k_320d_large_uni" in path
146
-
147
- def get_codec_type_from_range(self, ids):
148
- ids_range = [ids.min(), ids.max()]
149
- codec_range = self.mm_v0_2_cfg["codec_range"]
150
- for codec_type, r in codec_range.items():
151
- if ids_range[0] >= r[0] and ids_range[1] <= r[1]:
152
- return codec_type
153
- raise ValueError(f"ids_range={ids_range}, codec_range={codec_range}")
154
-
155
- def npy2ids(self, npy):
156
- if isinstance(npy, str):
157
- data = np.load(npy)
158
- elif isinstance(npy, np.ndarray):
159
- data = npy
160
- else:
161
- raise ValueError(f"not supported type: {type(npy)}")
162
- # data = data.squeeze()
163
-
164
- assert len(data.shape)==2, f'data shape: {data.shape} is not (n_codebook, seq_len)'
165
- data = self.offset_tok_ids(
166
- data,
167
- global_offset=self.global_offset,
168
- codebook_size=self.codebook_size,
169
- num_codebooks=self.num_codebooks,
170
- )
171
- data = self.flatten(data)
172
- codec_range = self.get_codec_type_from_range(data)
173
- assert codec_range == self.codec_type, f"get_codec_type_from_range(data)={codec_range}, self.codec_type={self.codec_type}"
174
- data = data.tolist()
175
- return data
176
-
177
- def ids2npy(self, token_ids):
178
- # make sure token_ids starts with codebook 0
179
- if isinstance(self.codebook_size, int):
180
- codebook_0_range = (self.global_offset + self.quantizer_begin*self.codebook_size, self.global_offset + (self.quantizer_begin+1)*self.codebook_size)
181
- elif isinstance(self.codebook_size, list):
182
- codebook_0_range = (self.global_offset, self.global_offset + self.codebook_size[0])
183
- assert token_ids[0] >= codebook_0_range[0] \
184
- and token_ids[0] < codebook_0_range[1], f"token_ids[0]={token_ids[self.quantizer_begin]}, codebook_0_range={codebook_0_range}"
185
- data = np.array(token_ids)
186
- data = self.unflatten(data, n_quantizer=self.n_quantizer)
187
- data = self.unoffset_tok_ids(
188
- data,
189
- global_offset=self.global_offset,
190
- codebook_size=self.codebook_size,
191
- num_codebooks=self.num_codebooks,
192
- )
193
- return data
194
-
195
- def npy_to_json_str(self, npy_path):
196
- data = self.npy2ids(npy_path)
197
- return json.dumps({"text": data, "src": npy_path, "codec": self.codec_type})
198
-
199
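- # NOTE: the two helpers below are shadowed by the instance attributes self.sep / self.sep_ids
- # assigned in __init__, so callers (e.g. infer.py) see the attribute values, not these methods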
- def sep(self):
200
- return ''.join(self.sep)
201
-
202
- def sep_ids(self):
203
- return self.sep_ids
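
For reference, a minimal round-trip sketch of how this class was used by the deleted `infer.py` (single xcodec codebook, i.e. `CodecManipulator("xcodec", 0, 1)`); it assumes the module above is still available on the import path:

```python
import numpy as np
from codecmanipulator import CodecManipulator

codectool = CodecManipulator("xcodec", 0, 1)  # xcodec vocab starts at global_offset 45334

# fake codec frames: (n_codebooks, seq_len) with entries below codebook_size (1024)
codes = np.random.randint(0, 1024, size=(1, 8))

token_ids = codectool.npy2ids(codes)      # offset into the LM vocab range and flatten to a list
recovered = codectool.ids2npy(token_ids)  # unflatten and subtract the offsets again

assert np.array_equal(codes, recovered)
```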
 
inference/infer.py DELETED
@@ -1,318 +0,0 @@
1
- import os
2
- import sys
3
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
4
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
5
- import argparse
6
- import torch
7
- import numpy as np
8
- import json
9
- from omegaconf import OmegaConf
10
- import torchaudio
11
- from torchaudio.transforms import Resample
12
- import soundfile as sf
13
-
14
- import uuid
15
- from tqdm import tqdm
16
- from einops import rearrange
17
- from codecmanipulator import CodecManipulator
18
- from mmtokenizer import _MMSentencePieceTokenizer
19
- from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
20
- import glob
21
- import time
22
- import copy
23
- from collections import Counter
24
- from models.soundstream_hubert_new import SoundStream
25
- from vocoder import build_codec_model, process_audio
26
- from post_process_audio import replace_low_freq_with_energy_matched
27
- import re
28
-
29
-
30
- parser = argparse.ArgumentParser()
31
- # Model Configuration:
32
- parser.add_argument("--stage1_model", type=str, default="m-a-p/YuE-s1-7B-anneal-en-cot", help="The model checkpoint path or identifier for the Stage 1 model.")
33
- parser.add_argument("--max_new_tokens", type=int, default=3000, help="The maximum number of new tokens to generate in one pass during text generation.")
34
- parser.add_argument("--run_n_segments", type=int, default=2, help="The number of segments to process during the generation.")
35
- # Prompt
36
- parser.add_argument("--genre_txt", type=str, required=True, help="The file path to a text file containing genre tags that describe the musical style or characteristics (e.g., instrumental, genre, mood, vocal timbre, vocal gender). This is used as part of the generation prompt.")
37
- parser.add_argument("--lyrics_txt", type=str, required=True, help="The file path to a text file containing the lyrics for the music generation. These lyrics will be processed and split into structured segments to guide the generation process.")
38
- parser.add_argument("--use_audio_prompt", action="store_true", help="If set, the model will use an audio file as a prompt during generation. The audio file should be specified using --audio_prompt_path.")
39
- parser.add_argument("--audio_prompt_path", type=str, default="", help="The file path to an audio file to use as a reference prompt when --use_audio_prompt is enabled.")
40
- parser.add_argument("--prompt_start_time", type=float, default=0.0, help="The start time in seconds to extract the audio prompt from the given audio file.")
41
- parser.add_argument("--prompt_end_time", type=float, default=30.0, help="The end time in seconds to extract the audio prompt from the given audio file.")
42
- # Output
43
- parser.add_argument("--output_dir", type=str, default="./output", help="The directory where generated outputs will be saved.")
44
- parser.add_argument("--keep_intermediate", action="store_true", help="If set, intermediate outputs will be saved during processing.")
45
- parser.add_argument("--disable_offload_model", action="store_true", help="If set, the model will not be offloaded from the GPU to CPU after Stage 1 inference.")
46
- parser.add_argument("--cuda_idx", type=int, default=0)
47
- # Config for xcodec and upsampler
48
- parser.add_argument('--basic_model_config', default='./xcodec_mini_infer/final_ckpt/config.yaml', help='YAML files for xcodec configurations.')
49
- parser.add_argument('--resume_path', default='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', help='Path to the xcodec checkpoint.')
50
- parser.add_argument('--config_path', type=str, default='./xcodec_mini_infer/decoders/config.yaml', help='Path to Vocos config file.')
51
- parser.add_argument('--vocal_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
52
- parser.add_argument('--inst_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
53
- parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
54
-
55
-
56
- args = parser.parse_args()
57
- if args.use_audio_prompt and not args.audio_prompt_path:
58
- raise FileNotFoundError("Please provide an audio prompt file path via '--audio_prompt_path' when '--use_audio_prompt' is enabled.")
59
- model = args.stage1_model
60
- cuda_idx = args.cuda_idx
61
- max_new_tokens = args.max_new_tokens
62
- stage1_output_dir = os.path.join(args.output_dir, f"stage1")
63
- os.makedirs(stage1_output_dir, exist_ok=True)
64
-
65
- # load tokenizer and model
66
- device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
67
-
68
- # Now you can use `device` to move your tensors or models to the GPU (if available)
69
- print(f"Using device: {device}")
70
-
71
- mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
-
- # Load the Stage 1 LM; later code calls model.generate()/model.cpu(), but `model` above only
- # held the checkpoint path. The torch_dtype here is an assumption; adjust to your hardware.
- model = AutoModelForCausalLM.from_pretrained(args.stage1_model, torch_dtype=torch.bfloat16).to(device)
- model.eval()
72
-
73
- codectool = CodecManipulator("xcodec", 0, 1)
74
- model_config = OmegaConf.load(args.basic_model_config)
75
- codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
76
- parameter_dict = torch.load(args.resume_path, map_location='cpu')
77
- codec_model.load_state_dict(parameter_dict['codec_model'])
78
- codec_model.to(device)
79
- codec_model.eval()
80
-
81
- class BlockTokenRangeProcessor(LogitsProcessor):
82
- def __init__(self, start_id, end_id):
83
- self.blocked_token_ids = list(range(start_id, end_id))
84
-
85
- def __call__(self, input_ids, scores):
86
- scores[:, self.blocked_token_ids] = -float("inf")
87
- return scores
88
-
89
- def load_audio_mono(filepath, sampling_rate=16000):
90
- audio, sr = torchaudio.load(filepath)
91
- # Convert to mono
92
- audio = torch.mean(audio, dim=0, keepdim=True)
93
- # Resample if needed
94
- if sr != sampling_rate:
95
- resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
96
- audio = resampler(audio)
97
- return audio
98
-
99
- def split_lyrics(lyrics):
100
- pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
101
- segments = re.findall(pattern, lyrics, re.DOTALL)
102
- structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
103
- return structured_lyrics
104
-
105
- # Load the genre tags and lyrics, then build the per-segment prompts
106
- stage1_output_set = []
107
- # Tips:
108
- # genre tags support instrumental, genre, mood, vocal timbre and vocal gender
109
- # all kinds of tags are needed
110
- with open(args.genre_txt) as f:
111
- genres = f.read().strip()
112
- with open(args.lyrics_txt) as f:
113
- lyrics = split_lyrics(f.read())
114
- # instruction
115
- full_lyrics = "\n".join(lyrics)
116
- prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
117
- prompt_texts += lyrics
118
-
119
-
120
- random_id = uuid.uuid4()
121
- output_seq = None
122
- # Suggested decoding config
123
- top_p = 0.93
124
- temperature = 1.0
125
- repetition_penalty = 1.2
126
- # special tokens
127
- start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
128
- end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
129
-
130
- raw_output = None
131
-
132
- # Format text prompt
133
- run_n_segments = min(args.run_n_segments+1, len(lyrics))
134
-
135
- print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
136
-
137
- for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
138
- section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
139
- guidance_scale = 1.5 if i <=1 else 1.2
140
- if i==0:
141
- continue
142
- if i==1:
143
- if args.use_audio_prompt:
144
- audio_prompt = load_audio_mono(args.audio_prompt_path)
145
- audio_prompt.unsqueeze_(0)
146
- with torch.no_grad():
147
- raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
148
- raw_codes = raw_codes.transpose(0, 1)
149
- raw_codes = raw_codes.cpu().numpy().astype(np.int16)
150
- # Format audio prompt
151
- code_ids = codectool.npy2ids(raw_codes[0])
152
- audio_prompt_codec = code_ids[int(args.prompt_start_time *50): int(args.prompt_end_time *50)] # 50 is tps of xcodec
153
- audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
154
- sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
155
- head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
156
- else:
157
- head_id = mmtokenizer.tokenize(prompt_texts[0])
158
- prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
159
- else:
160
- prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
161
-
162
- prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
163
- input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
164
- # Use window slicing in case output sequence exceeds the context of model
165
- max_context = 16384-max_new_tokens-1
166
- if input_ids.shape[-1] > max_context:
167
- print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
168
- input_ids = input_ids[:, -(max_context):]
169
- with torch.no_grad():
170
- output_seq = model.generate(
171
- input_ids=input_ids,
172
- max_new_tokens=max_new_tokens,
173
- min_new_tokens=100,
174
- do_sample=True,
175
- top_p=top_p,
176
- temperature=temperature,
177
- repetition_penalty=repetition_penalty,
178
- eos_token_id=mmtokenizer.eoa,
179
- pad_token_id=mmtokenizer.eoa,
180
- logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
181
- guidance_scale=guidance_scale,
182
- )
183
- if output_seq[0][-1].item() != mmtokenizer.eoa:
184
- tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
185
- output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
186
- if i > 1:
187
- raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
188
- else:
189
- raw_output = output_seq
190
- print(len(raw_output))
191
-
192
- # save raw output and check sanity
193
- ids = raw_output[0].cpu().numpy()
194
- soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
195
- eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
196
- if len(soa_idx)!=len(eoa_idx):
197
- raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
198
-
199
- vocals = []
200
- instrumentals = []
201
- range_begin = 1 if args.use_audio_prompt else 0
202
- for i in range(range_begin, len(soa_idx)):
203
- codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
204
- if codec_ids[0] == 32016:
205
- codec_ids = codec_ids[1:]
206
- codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
207
- vocals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[0])
208
- vocals.append(vocals_ids)
209
- instrumentals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[1])
210
- instrumentals.append(instrumentals_ids)
211
- vocals = np.concatenate(vocals, axis=1)
212
- instrumentals = np.concatenate(instrumentals, axis=1)
213
- vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace('.', '@')+'.npy')
214
- inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace('.', '@')+'.npy')
215
- np.save(vocal_save_path, vocals)
216
- np.save(inst_save_path, instrumentals)
217
- stage1_output_set.append(vocal_save_path)
218
- stage1_output_set.append(inst_save_path)
219
-
220
-
221
- # offload model
222
- if not args.disable_offload_model:
223
- model.cpu()
224
- del model
225
- torch.cuda.empty_cache()
226
-
227
- print("Converting to Audio...")
228
-
229
- # convert audio tokens to audio
230
- def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
231
- folder_path = os.path.dirname(path)
232
- if not os.path.exists(folder_path):
233
- os.makedirs(folder_path)
234
- limit = 0.99
235
- max_val = wav.abs().max()
236
- wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
237
- torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
238
- # reconstruct tracks
239
- recons_output_dir = os.path.join(args.output_dir, "recons")
240
- recons_mix_dir = os.path.join(recons_output_dir, 'mix')
241
- os.makedirs(recons_mix_dir, exist_ok=True)
242
- tracks = []
243
- for npy in stage1_output_set:
244
- codec_result = np.load(npy)
245
- decodec_rlt=[]
246
- with torch.no_grad():
247
- decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
248
- decoded_waveform = decoded_waveform.cpu().squeeze(0)
249
- decodec_rlt.append(torch.as_tensor(decoded_waveform))
250
- decodec_rlt = torch.cat(decodec_rlt, dim=-1)
251
- save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
252
- tracks.append(save_path)
253
- save_audio(decodec_rlt, save_path, 16000)
254
- # mix tracks
255
- for inst_path in tracks:
256
- try:
257
- if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
258
- and 'instrumental' in inst_path:
259
- # find pair
260
- vocal_path = inst_path.replace('instrumental', 'vocal')
261
- if not os.path.exists(vocal_path):
262
- continue
263
- # mix
264
- recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
265
- instrumental_stem, sr = sf.read(inst_path)
266
- vocal_stem, _ = sf.read(vocal_path)
267
- mix_stem = (vocal_stem + instrumental_stem) / 1
268
- sf.write(recons_mix, mix_stem, sr)
269
- except Exception as e:
270
- print(e)
271
-
272
- # vocoder to upsample audios
273
- vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)
274
- vocoder_output_dir = os.path.join(args.output_dir, 'vocoder')
275
- vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
276
- vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
277
- os.makedirs(vocoder_mix_dir, exist_ok=True)
278
- os.makedirs(vocoder_stems_dir, exist_ok=True)
279
-
280
- for npy in stage1_output_set:
281
- if 'instrumental' in npy:
282
- # Process instrumental
283
- instrumental_output = process_audio(
284
- npy,
285
- os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
286
- args.rescale,
287
- args,
288
- inst_decoder,
289
- codec_model
290
- )
291
- else:
292
- # Process vocal
293
- vocal_output = process_audio(
294
- npy,
295
- os.path.join(vocoder_stems_dir, 'vocal.mp3'),
296
- args.rescale,
297
- args,
298
- vocal_decoder,
299
- codec_model
300
- )
301
- # mix tracks
302
- try:
303
- mix_output = instrumental_output + vocal_output
304
- vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
305
- save_audio(mix_output, vocoder_mix, 44100, args.rescale)
306
- print(f"Created mix: {vocoder_mix}")
307
- except RuntimeError as e:
308
- print(e)
309
- print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
310
-
311
- # Post process
312
- replace_low_freq_with_energy_matched(
313
- a_file=recons_mix, # 16kHz
314
- b_file=vocoder_mix, # 48kHz
315
- c_file=os.path.join(args.output_dir, os.path.basename(recons_mix)),
316
- cutoff_freq=5500.0
317
- )
318
- print("All processes done")
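
The script was driven from the command line (`--genre_txt` and `--lyrics_txt` are required; the other flags have defaults). For reference, a small standalone sketch of the prompt construction it performed, using the example prompt files deleted further below:

```python
import re

def split_lyrics(lyrics):
    # same pattern as the deleted infer.py: one chunk per [section] header
    pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
    segments = re.findall(pattern, lyrics, re.DOTALL)
    return [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]

# contents of prompt_examples/genre.txt and a shortened prompt_examples/lyrics.txt
genres = "inspiring female uplifting pop airy vocal electronic bright vocal vocal"
lyrics = split_lyrics("[verse]\nStaring at the sunset, colors paint the sky\n\n[chorus]\nI won't back down\n")

full_lyrics = "\n".join(lyrics)
prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
prompt_texts += lyrics  # index 0 is the head prompt, 1..N are the per-segment prompts

print(prompt_texts[1])  # "[verse]\nStaring at the sunset, colors paint the sky\n\n"
```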
 
inference/mm_tokenizer_v0.2_hf/tokenizer.model DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ee5c7cbf32da93989f14d9ba635e3e1d1ab2cc88a92908a5ed0f149375f6ee49
3
- size 1761962
 
inference/mmtokenizer.py DELETED
@@ -1,367 +0,0 @@
1
- from abc import ABC
2
- from abc import abstractmethod
3
-
4
-
5
- class AbstractTokenizer(ABC):
6
- """Abstract class for tokenizer."""
7
-
8
- def __init__(self, name):
9
- self.name = name
10
- super().__init__()
11
-
12
- @property
13
- @abstractmethod
14
- def vocab_size(self):
15
- pass
16
-
17
- @property
18
- @abstractmethod
19
- def vocab(self):
20
- """Dictionary from vocab text token to id token."""
21
- pass
22
-
23
- @property
24
- @abstractmethod
25
- def inv_vocab(self):
26
- """Dictionary from vocab id token to text token."""
27
- pass
28
-
29
- @abstractmethod
30
- def tokenize(self, text):
31
- pass
32
-
33
- def detokenize(self, token_ids):
34
- raise NotImplementedError('detokenizer is not implemented for {} '
35
- 'tokenizer'.format(self.name))
36
-
37
- @property
38
- def cls(self):
39
- raise NotImplementedError('CLS is not provided for {} '
40
- 'tokenizer'.format(self.name))
41
-
42
- @property
43
- def sep(self):
44
- raise NotImplementedError('SEP is not provided for {} '
45
- 'tokenizer'.format(self.name))
46
-
47
- @property
48
- def pad(self):
49
- raise NotImplementedError('PAD is not provided for {} '
50
- 'tokenizer'.format(self.name))
51
-
52
- @property
53
- def eod(self):
54
- raise NotImplementedError('EOD is not provided for {} '
55
- 'tokenizer'.format(self.name))
56
-
57
- @property
58
- def mask(self):
59
- raise NotImplementedError('MASK is not provided for {} '
60
- 'tokenizer'.format(self.name))
61
-
62
-
63
- class _SentencePieceTokenizer(AbstractTokenizer):
64
- """SentencePieceTokenizer-Megatron wrapper"""
65
-
66
- def __init__(self, model_file, vocab_extra_ids=0):
67
- name = 'SentencePieceTokenizer'
68
- super().__init__(name)
69
-
70
- import sentencepiece
71
- self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
72
- self._initialize(vocab_extra_ids)
73
-
74
- def _populate_vocab(self):
75
- self._vocab = {}
76
- self._inv_vocab = {}
77
-
78
- for i in range(len(self.tokenizer)):
79
- t = self.tokenizer.id_to_piece(i)
80
- self._inv_vocab[i] = t
81
- self._vocab[t] = i
82
-
83
- def _initialize(self, vocab_extra_ids):
84
- self._populate_vocab()
85
- self._special_tokens = {}
86
- self._inv_special_tokens = {}
87
-
88
- self._t5_tokens = []
89
-
90
- def _add_special_token(t):
91
- if t not in self._vocab:
92
- next_id = len(self._vocab)
93
- self._vocab[t] = next_id
94
- self._inv_vocab[next_id] = t
95
- self._special_tokens[t] = self._vocab[t]
96
- self._inv_special_tokens[self._vocab[t]] = t
97
-
98
- _add_special_token('<CLS>')
99
- self._cls_id = self._vocab['<CLS>']
100
- _add_special_token('<SEP>')
101
- self._sep_id = self._vocab['<SEP>']
102
- _add_special_token('<EOD>')
103
- self._eod_id = self._vocab['<EOD>']
104
- _add_special_token('<MASK>')
105
- self._mask_id = self._vocab['<MASK>']
106
-
107
- pad_id = self.tokenizer.pad_id()
108
- try:
109
- pad_token = self.tokenizer.id_to_piece(pad_id)
110
- except IndexError:
111
- pad_token = '<PAD>'
112
- _add_special_token(pad_token)
113
- self._pad_id = self._vocab[pad_token]
114
-
115
- bos_id = self.tokenizer.bos_id()
116
- try:
117
- bos_token = self.tokenizer.id_to_piece(bos_id)
118
- except IndexError:
119
- bos_token = '<BOS>'
120
- _add_special_token(bos_token)
121
- self._bos_id = self._vocab[bos_token]
122
-
123
- eos_id = self.tokenizer.eos_id()
124
- try:
125
- eos_token = self.tokenizer.id_to_piece(eos_id)
126
- except IndexError:
127
- eos_token = '<EOS>'
128
- _add_special_token(eos_token)
129
- self._eos_id = self._vocab[eos_token]
130
-
131
- for i in range(vocab_extra_ids):
132
- t = "<extra_id_{}>".format(i)
133
- _add_special_token(t)
134
- self._t5_tokens += [t]
135
-
136
- @property
137
- def vocab_size(self):
138
- return len(self._vocab)
139
-
140
- @property
141
- def vocab(self):
142
- return self._vocab
143
-
144
- @property
145
- def inv_vocab(self):
146
- return self._inv_vocab
147
-
148
- @property
149
- def decoder(self):
150
- return self._inv_vocab
151
-
152
- @property
153
- def encoder(self):
154
- return self._vocab
155
-
156
- # From:
157
- # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
158
- def tokenize(self, text):
159
- ids = []
160
- idx = 0
161
-
162
- while 1:
163
- indices = {}
164
- for token in self._special_tokens:
165
- try:
166
- indices[token] = text[idx:].index(token)
167
- except ValueError:
168
- continue
169
- if len(indices) == 0:
170
- break
171
-
172
- next_token = min(indices, key=indices.get)
173
- next_idx = idx + indices[next_token]
174
-
175
- ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
176
- ids.append(self._special_tokens[next_token])
177
- idx = next_idx + len(next_token)
178
-
179
- ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
180
- return ids
181
-
182
- # From:
183
- # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125
184
- def detokenize(self, ids):
185
- text = ""
186
- last_i = 0
187
-
188
- for i, id in enumerate(ids):
189
- if id in self._inv_special_tokens:
190
- text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
191
- text += self._inv_special_tokens[id] + " "
192
- last_i = i + 1
193
-
194
- text += self.tokenizer.decode_ids(ids[last_i:])
195
- return text
196
-
197
- @property
198
- def cls(self):
199
- return self._cls_id
200
-
201
- @property
202
- def sep(self):
203
- return self._sep_id
204
-
205
- @property
206
- def pad(self):
207
- return self._pad_id
208
-
209
- @property
210
- def bos_token_id(self):
211
- return self._bos_id
212
-
213
- @property
214
- def bos(self):
215
- return self._bos_id
216
-
217
- @property
218
- def eod(self):
219
- return self._eod_id
220
-
221
- @property
222
- def eos_token_id(self):
223
- return self._eos_id
224
-
225
- @property
226
- def eos(self):
227
- return self._eos_id
228
-
229
- @property
230
- def mask(self):
231
- return self._mask_id
232
-
233
- @property
234
- def additional_special_tokens_ids(self):
235
- return [self.vocab[k] for k in self._t5_tokens]
236
-
237
- class _MMSentencePieceTokenizer(_SentencePieceTokenizer):
238
- """SentencePieceTokenizer-Megatron wrapper"""
239
-
240
- def __init__(self, model_file, vocab_extra_ids=0):
241
- super().__init__(model_file, vocab_extra_ids)
242
-
243
-
244
- def _initialize(self, vocab_extra_ids):
245
- self._populate_vocab()
246
- self._special_tokens = {}
247
- self._inv_special_tokens = {}
248
-
249
- self._t5_tokens = []
250
-
251
- def _add_special_token(t):
252
- if t not in self._vocab:
253
- next_id = len(self._vocab)
254
- self._vocab[t] = next_id
255
- self._inv_vocab[next_id] = t
256
- self._special_tokens[t] = self._vocab[t]
257
- self._inv_special_tokens[self._vocab[t]] = t
258
-
259
- _add_special_token('<CLS>')
260
- self._cls_id = self._vocab['<CLS>']
261
- _add_special_token('<SEP>')
262
- self._sep_id = self._vocab['<SEP>']
263
- _add_special_token('<EOD>')
264
- self._eod_id = self._vocab['<EOD>']
265
- _add_special_token('<MASK>')
266
- self._mask_id = self._vocab['<MASK>']
267
-
268
- _add_special_token('<SOA>')
269
- self._soa_id = self._vocab['<SOA>']
270
- _add_special_token('<EOA>')
271
- self._eoa_id = self._vocab['<EOA>']
272
- _add_special_token('<SOV>')
273
- self._sov_id = self._vocab['<SOV>']
274
- _add_special_token('<EOV>')
275
- self._eov_id = self._vocab['<EOV>']
276
- _add_special_token('<SOI>')
277
- self._soi_id = self._vocab['<SOI>']
278
- _add_special_token('<EOI>')
279
- self._eoi_id = self._vocab['<EOI>']
280
- _add_special_token('<s_local>')
281
- self._s_local_id = self._vocab['<s_local>']
282
- _add_special_token('<e_local>')
283
- self._e_local_id = self._vocab['<e_local>']
284
- _add_special_token('<s_global>')
285
- self._s_global_id = self._vocab['<s_global>']
286
- _add_special_token('<e_global>')
287
- self._e_global_id = self._vocab['<e_global>']
288
- _add_special_token('<stage_1>')
289
- self._stage_1_id = self._vocab['<stage_1>']
290
- _add_special_token('<stage_2>')
291
- self._stage_2_id = self._vocab['<stage_2>']
292
- pad_id = self.tokenizer.pad_id()
293
- try:
294
- pad_token = self.tokenizer.id_to_piece(pad_id)
295
- except IndexError:
296
- pad_token = '<PAD>'
297
- _add_special_token(pad_token)
298
- self._pad_id = self._vocab[pad_token]
299
-
300
- bos_id = self.tokenizer.bos_id()
301
- try:
302
- bos_token = self.tokenizer.id_to_piece(bos_id)
303
- except IndexError:
304
- bos_token = '<BOS>'
305
- _add_special_token(bos_token)
306
- self._bos_id = self._vocab[bos_token]
307
-
308
- eos_id = self.tokenizer.eos_id()
309
- try:
310
- eos_token = self.tokenizer.id_to_piece(eos_id)
311
- except IndexError:
312
- eos_token = '<EOS>'
313
- _add_special_token(eos_token)
314
- self._eos_id = self._vocab[eos_token]
315
-
316
- for i in range(vocab_extra_ids):
317
- t = "<extra_id_{}>".format(i)
318
- _add_special_token(t)
319
- self._t5_tokens += [t]
320
-
321
- @property
322
- def soa(self):
323
- return self._soa_id
324
-
325
- @property
326
- def eoa(self):
327
- return self._eoa_id
328
-
329
- @property
330
- def sov(self):
331
- return self._sov_id
332
-
333
- @property
334
- def eov(self):
335
- return self._eov_id
336
-
337
- @property
338
- def soi(self):
339
- return self._soi_id
340
-
341
- @property
342
- def eoi(self):
343
- return self._eoi_id
344
-
345
- @property
346
- def s_local(self):
347
- return self._s_local_id
348
-
349
- @property
350
- def e_local(self):
351
- return self._e_local_id
352
-
353
- @property
354
- def s_global(self):
355
- return self._s_global_id
356
-
357
- @property
358
- def e_global(self):
359
- return self._e_global_id
360
-
361
- @property
362
- def stage_1(self):
363
- return self._stage_1_id
364
-
365
- @property
366
- def stage_2(self):
367
- return self._stage_2_id
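
For reference, a minimal usage sketch of the multimodal tokenizer above; it assumes `sentencepiece` is installed and that the `tokenizer.model` file removed in this commit is available at the path the deleted `infer.py` used:

```python
from mmtokenizer import _MMSentencePieceTokenizer

mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")

# special tokens such as <SOA>/<EOA> are matched literally and mapped to reserved ids
print(mmtokenizer.soa, mmtokenizer.eoa)

ids = mmtokenizer.tokenize("[verse] hello world <SOA>")
print(mmtokenizer.detokenize(ids))
```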
 
inference/prompt_examples/genre.txt DELETED
@@ -1 +0,0 @@
1
- inspiring female uplifting pop airy vocal electronic bright vocal vocal
 
 
inference/prompt_examples/lyrics.txt DELETED
@@ -1,39 +0,0 @@
1
- [verse]
2
- Staring at the sunset, colors paint the sky
3
- Thoughts of you keep swirling, can't deny
4
- I know I let you down, I made mistakes
5
- But I'm here to mend the heart I didn't break
6
-
7
- [chorus]
8
- Every road you take, I'll be one step behind
9
- Every dream you chase, I'm reaching for the light
10
- You can't fight this feeling now
11
- I won't back down
12
- You know you can't deny it now
13
- I won't back down
14
-
15
- [verse]
16
- They might say I'm foolish, chasing after you
17
- But they don't feel this love the way we do
18
- My heart beats only for you, can't you see?
19
- I won't let you slip away from me
20
-
21
- [chorus]
22
- Every road you take, I'll be one step behind
23
- Every dream you chase, I'm reaching for the light
24
- You can't fight this feeling now
25
- I won't back down
26
- You know you can't deny it now
27
- I won't back down
28
-
29
- [bridge]
30
- No, I won't back down, won't turn around
31
- Until you're back where you belong
32
- I'll cross the oceans wide, stand by your side
33
- Together we are strong
34
-
35
- [outro]
36
- Every road you take, I'll be one step behind
37
- Every dream you chase, love's the tie that binds
38
- You can't fight this feeling now
39
- I won't back down