KingNish commited on
Commit
15dff17
·
verified ·
1 Parent(s): b7475bc

Delete inference

Browse files
inference/codecmanipulator.py DELETED
@@ -1,203 +0,0 @@
1
- import json
2
- import numpy as np
3
- import einops
4
-
5
-
6
- class CodecManipulator(object):
7
- r"""
8
- **mm tokenizer v0.1**
9
- see codeclm/hf/mm_tokenizer_v0.1_hf/id2vocab.json
10
-
11
- text tokens:
12
- llama tokenizer 0~31999
13
-
14
- special tokens: "32000": "<EOD>", "32001": "<SOA>", "32002": "<EOA>", "32003": "<SOI>", "32004": "<EOI>", "32005": "<SOV>", "32006": "<EOV>", "32007": "<s_local>", "32008": "<e_local>", "32009": "<s_global>", "32010": "<e_global>", "32011": "<semantic>", "32012": "<acoustic>", "32013": "<low_level>", "32014": "<dac_16k>", "32015": "<dac_44k>", "32016": "<xcodec>", "32017": "<placeholder>", "32018": "<semantic_mert>", "32019": "<semantic_hubert>", "32020": "<visual>", "32021": "<semanticodec>"
15
-
16
- mm tokens:
17
- dac_16k: 4 codebook, 1024 vocab, 32022 - 36117
18
- dac_44k: 9 codebook, 1024 vocab, 36118 - 45333
19
- xcodec: 12 codebook, 1024 vocab, 45334 - 57621
20
- semantic mert: 1024, 57622 - 58645
21
- semantic hubert: 512, 58646 - 59157
22
- visual: 64000, not included in v0.1
23
- semanticodec 100tps 16384: semantic=16384, 59158 - 75541, acoustic=8192, 75542 - 83733
24
- """
25
- def __init__(self, codec_type, quantizer_begin=None, n_quantizer=None, teacher_forcing=False, data_feature="codec"):
26
- self.codec_type = codec_type
27
- self.mm_v0_2_cfg = {
28
- "dac16k": {"codebook_size": 1024, "num_codebooks": 4, "global_offset": 32022, "sep": ["<dac_16k>"], "fps": 50},
29
- "dac44k": {"codebook_size": 1024, "num_codebooks": 9, "global_offset": 36118, "sep": ["<dac_44k>"]},
30
- "xcodec": {"codebook_size": 1024, "num_codebooks": 12, "global_offset": 45334, "sep": ["<xcodec>"], "fps": 50},
31
- "mert": {"codebook_size": 1024, "global_offset": 57622, "sep": ["<semantic_mert>"]},
32
- "hubert": {"codebook_size": 512, "global_offset": 58646, "sep": ["<semantic_hubert>"]},
33
- "semantic/s": {"codebook_size": 16384, "num_codebooks": 1, "global_offset": 59158, "sep": ["<semanticodec>", "<semantic>"]},
34
- "semantic/a": {"codebook_size": 8192, "num_codebooks": 1, "global_offset": 75542, "sep": ["<semanticodec>", "<acoustic>"]},
35
- "semanticodec": {"codebook_size": [16384, 8192], "num_codebooks": 2, "global_offset": 59158, "sep": ["<semanticodec>"], "fps": 50},
36
- "special_tokens": {
37
- '<EOD>': 32000, '<SOA>': 32001, '<EOA>': 32002, '<SOI>': 32003, '<EOI>': 32004, '<SOV>': 32005, '<EOV>': 32006, '<s_local>': 32007, '<e_local>': 32008, '<s_global>': 32009, '<e_global>': 32010, '<semantic>': 32011, '<acoustic>': 32012, '<stage_1>': 32013, '<dac_16k>': 32014, '<dac_44k>': 32015, '<xcodec>': 32016, '<stage_2>': 32017, '<semantic_mert>': 32018, '<semantic_hubert>': 32019, '<visual>': 32020, '<semanticodec>': 32021
38
- },
39
- "metadata": {
40
- "len": 83734,
41
- "text_range": [0, 31999],
42
- "special_range": [32000, 32021],
43
- "mm_range": [32022, 83733]
44
- },
45
- "codec_range": {
46
- "dac16k": [32022, 36117],
47
- "dac44k": [36118, 45333],
48
- "xcodec": [45334, 57621],
49
- # "hifi16k": [53526, 57621],
50
- "mert": [57622, 58645],
51
- "hubert": [58646, 59157],
52
- "semantic/s": [59158, 75541],
53
- "semantic/a": [75542, 83733],
54
- "semanticodec": [59158, 83733]
55
- }
56
- }
57
- self.sep = self.mm_v0_2_cfg[self.codec_type]["sep"]
58
- self.sep_ids = [self.mm_v0_2_cfg["special_tokens"][s] for s in self.sep]
59
- self.codebook_size = self.mm_v0_2_cfg[self.codec_type]["codebook_size"]
60
- self.num_codebooks = self.mm_v0_2_cfg[self.codec_type]["num_codebooks"]
61
- self.global_offset = self.mm_v0_2_cfg[self.codec_type]["global_offset"]
62
- self.fps = self.mm_v0_2_cfg[self.codec_type]["fps"] if "fps" in self.mm_v0_2_cfg[self.codec_type] else None
63
-
64
- self.quantizer_begin = quantizer_begin if quantizer_begin is not None else 0
65
- self.n_quantizer = n_quantizer if n_quantizer is not None else self.num_codebooks
66
- self.teacher_forcing = teacher_forcing
67
- self.data_feature = data_feature
68
-
69
-
70
- def offset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
71
- """
72
- x: (K, T)
73
- """
74
- if isinstance(codebook_size, int):
75
- assert x.max() < codebook_size, f"max(x)={x.max()}, codebook_size={codebook_size}"
76
- elif isinstance(codebook_size, list):
77
- for i, cs in enumerate(codebook_size):
78
- assert x[i].max() < cs, f"max(x)={x[i].max()}, codebook_size={cs}, layer_id={i}"
79
- else:
80
- raise ValueError(f"codebook_size={codebook_size}")
81
- assert x.min() >= 0, f"min(x)={x.min()}"
82
- assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
83
- f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"
84
-
85
- _x = x.copy()
86
- _x = _x.astype(np.uint32)
87
- cum_offset = 0
88
- quantizer_begin = self.quantizer_begin
89
- quantizer_end = quantizer_begin+self.n_quantizer
90
- for k in range(self.quantizer_begin, quantizer_end): # k: quantizer_begin to quantizer_end - 1
91
- if isinstance(codebook_size, int):
92
- _x[k] += global_offset + k * codebook_size
93
- elif isinstance(codebook_size, list):
94
- _x[k] += global_offset + cum_offset
95
- cum_offset += codebook_size[k]
96
- else:
97
- raise ValueError(f"codebook_size={codebook_size}")
98
- return _x[quantizer_begin:quantizer_end]
99
-
100
- def unoffset_tok_ids(self, x, global_offset=0, codebook_size=2048, num_codebooks=4):
101
- """
102
- x: (K, T)
103
- """
104
- if isinstance(codebook_size, int):
105
- assert x.max() < global_offset + codebook_size * num_codebooks, f"max(x)={x.max()}, codebook_size={codebook_size}"
106
- elif isinstance(codebook_size, list):
107
- assert x.max() < global_offset + sum(codebook_size), f"max(x)={x.max()}, codebook_size={codebook_size}"
108
- assert x.min() >= global_offset, f"min(x)={x.min()}, global_offset={global_offset}"
109
- assert x.shape[0] == num_codebooks or x.shape[0] == self.n_quantizer, \
110
- f"x.shape[0]={x.shape[0]}, num_codebooks={num_codebooks}, n_quantizer={self.n_quantizer}"
111
-
112
- _x = x.copy()
113
- _x = _x.astype(np.uint32)
114
- cum_offset = 0
115
- quantizer_begin = self.quantizer_begin
116
- quantizer_end = quantizer_begin+self.n_quantizer
117
- for k in range(quantizer_begin, quantizer_end):
118
- if isinstance(codebook_size, int):
119
- _x[k-quantizer_begin] -= global_offset + k * codebook_size
120
- elif isinstance(codebook_size, list):
121
- _x[k-quantizer_begin] -= global_offset + cum_offset
122
- cum_offset += codebook_size[k]
123
- else:
124
- raise ValueError(f"codebook_size={codebook_size}")
125
- return _x
126
-
127
- def flatten(self, x):
128
- if len(x.shape) > 2:
129
- x = x.squeeze()
130
- assert x.shape[0] == self.num_codebooks or x.shape[0] == self.n_quantizer, \
131
- f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
132
- return einops.rearrange(x, 'K T -> (T K)')
133
-
134
- def unflatten(self, x, n_quantizer=None):
135
- x = x.squeeze()
136
- assert len(x.shape) == 1
137
- assert x.shape[0] % self.num_codebooks == 0 or x.shape[0] % self.n_quantizer == 0, \
138
- f"x.shape[0]={x.shape[0]}, num_codebooks={self.num_codebooks}, n_quantizer={self.n_quantizer}"
139
- if n_quantizer!=self.num_codebooks:
140
- return einops.rearrange(x, '(T K) -> K T', K=n_quantizer)
141
- return einops.rearrange(x, '(T K) -> K T', K=self.num_codebooks)
142
-
143
- # def check_codec_type_from_path(self, path):
144
- # if self.codec_type == "hifi16k":
145
- # assert "academicodec_hifi_16k_320d_large_uni" in path
146
-
147
- def get_codec_type_from_range(self, ids):
148
- ids_range = [ids.min(), ids.max()]
149
- codec_range = self.mm_v0_2_cfg["codec_range"]
150
- for codec_type, r in codec_range.items():
151
- if ids_range[0] >= r[0] and ids_range[1] <= r[1]:
152
- return codec_type
153
- raise ValueError(f"ids_range={ids_range}, codec_range={codec_range}")
154
-
155
- def npy2ids(self, npy):
156
- if isinstance(npy, str):
157
- data = np.load(npy)
158
- elif isinstance(npy, np.ndarray):
159
- data = npy
160
- else:
161
- raise ValueError(f"not supported type: {type(npy)}")
162
- # data = data.squeeze()
163
-
164
- assert len(data.shape)==2, f'data shape: {data.shape} is not (n_codebook, seq_len)'
165
- data = self.offset_tok_ids(
166
- data,
167
- global_offset=self.global_offset,
168
- codebook_size=self.codebook_size,
169
- num_codebooks=self.num_codebooks,
170
- )
171
- data = self.flatten(data)
172
- codec_range = self.get_codec_type_from_range(data)
173
- assert codec_range == self.codec_type, f"get_codec_type_from_range(data)={codec_range}, self.codec_type={self.codec_type}"
174
- data = data.tolist()
175
- return data
176
-
177
- def ids2npy(self, token_ids):
178
- # make sure token_ids starts with codebook 0
179
- if isinstance(self.codebook_size, int):
180
- codebook_0_range = (self.global_offset + self.quantizer_begin*self.codebook_size, self.global_offset + (self.quantizer_begin+1)*self.codebook_size)
181
- elif isinstance(self.codebook_size, list):
182
- codebook_0_range = (self.global_offset, self.global_offset + self.codebook_size[0])
183
- assert token_ids[0] >= codebook_0_range[0] \
184
- and token_ids[0] < codebook_0_range[1], f"token_ids[0]={token_ids[self.quantizer_begin]}, codebook_0_range={codebook_0_range}"
185
- data = np.array(token_ids)
186
- data = self.unflatten(data, n_quantizer=self.n_quantizer)
187
- data = self.unoffset_tok_ids(
188
- data,
189
- global_offset=self.global_offset,
190
- codebook_size=self.codebook_size,
191
- num_codebooks=self.num_codebooks,
192
- )
193
- return data
194
-
195
- def npy_to_json_str(self, npy_path):
196
- data = self.npy2ids(npy_path)
197
- return json.dumps({"text": data, "src": npy_path, "codec": self.codec_type})
198
-
199
- def sep(self):
200
- return ''.join(self.sep)
201
-
202
- def sep_ids(self):
203
- return self.sep_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference/infer.py DELETED
@@ -1,456 +0,0 @@
1
- import os
2
- import sys
3
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
4
- sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
5
- import argparse
6
- import torch
7
- import numpy as np
8
- import json
9
- from omegaconf import OmegaConf
10
- import torchaudio
11
- from torchaudio.transforms import Resample
12
- import soundfile as sf
13
-
14
- import uuid
15
- from tqdm import tqdm
16
- from einops import rearrange
17
- from codecmanipulator import CodecManipulator
18
- from mmtokenizer import _MMSentencePieceTokenizer
19
- from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
20
- import glob
21
- import time
22
- import copy
23
- from collections import Counter
24
- from models.soundstream_hubert_new import SoundStream
25
- from vocoder import build_codec_model, process_audio
26
- from post_process_audio import replace_low_freq_with_energy_matched
27
- import re
28
-
29
-
30
- parser = argparse.ArgumentParser()
31
- # Model Configuration:
32
- parser.add_argument("--stage1_model", type=str, default="m-a-p/YuE-s1-7B-anneal-en-cot", help="The model checkpoint path or identifier for the Stage 1 model.")
33
- parser.add_argument("--stage2_model", type=str, default="m-a-p/YuE-s2-1B-general", help="The model checkpoint path or identifier for the Stage 2 model.")
34
- parser.add_argument("--max_new_tokens", type=int, default=3000, help="The maximum number of new tokens to generate in one pass during text generation.")
35
- parser.add_argument("--run_n_segments", type=int, default=2, help="The number of segments to process during the generation.")
36
- parser.add_argument("--stage2_batch_size", type=int, default=4, help="The batch size used in Stage 2 inference.")
37
- # Prompt
38
- parser.add_argument("--genre_txt", type=str, required=True, help="The file path to a text file containing genre tags that describe the musical style or characteristics (e.g., instrumental, genre, mood, vocal timbre, vocal gender). This is used as part of the generation prompt.")
39
- parser.add_argument("--lyrics_txt", type=str, required=True, help="The file path to a text file containing the lyrics for the music generation. These lyrics will be processed and split into structured segments to guide the generation process.")
40
- parser.add_argument("--use_audio_prompt", action="store_true", help="If set, the model will use an audio file as a prompt during generation. The audio file should be specified using --audio_prompt_path.")
41
- parser.add_argument("--audio_prompt_path", type=str, default="", help="The file path to an audio file to use as a reference prompt when --use_audio_prompt is enabled.")
42
- parser.add_argument("--prompt_start_time", type=float, default=0.0, help="The start time in seconds to extract the audio prompt from the given audio file.")
43
- parser.add_argument("--prompt_end_time", type=float, default=30.0, help="The end time in seconds to extract the audio prompt from the given audio file.")
44
- # Output
45
- parser.add_argument("--output_dir", type=str, default="./output", help="The directory where generated outputs will be saved.")
46
- parser.add_argument("--keep_intermediate", action="store_true", help="If set, intermediate outputs will be saved during processing.")
47
- parser.add_argument("--disable_offload_model", action="store_true", help="If set, the model will not be offloaded from the GPU to CPU after Stage 1 inference.")
48
- parser.add_argument("--cuda_idx", type=int, default=0)
49
- # Config for xcodec and upsampler
50
- parser.add_argument('--basic_model_config', default='./xcodec_mini_infer/final_ckpt/config.yaml', help='YAML files for xcodec configurations.')
51
- parser.add_argument('--resume_path', default='./xcodec_mini_infer/final_ckpt/ckpt_00360000.pth', help='Path to the xcodec checkpoint.')
52
- parser.add_argument('--config_path', type=str, default='./xcodec_mini_infer/decoders/config.yaml', help='Path to Vocos config file.')
53
- parser.add_argument('--vocal_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
54
- parser.add_argument('--inst_decoder_path', type=str, default='./xcodec_mini_infer/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
55
- parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
56
-
57
-
58
- args = parser.parse_args()
59
- if args.use_audio_prompt and not args.audio_prompt_path:
60
- raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
61
- stage1_model = args.stage1_model
62
- stage2_model = args.stage2_model
63
- cuda_idx = args.cuda_idx
64
- max_new_tokens = args.max_new_tokens
65
- stage1_output_dir = os.path.join(args.output_dir, f"stage1")
66
- stage2_output_dir = stage1_output_dir.replace('stage1', 'stage2')
67
- os.makedirs(stage1_output_dir, exist_ok=True)
68
- os.makedirs(stage2_output_dir, exist_ok=True)
69
-
70
- # load tokenizer and model
71
- device = torch.device(f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu")
72
- mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
73
- model = AutoModelForCausalLM.from_pretrained(
74
- stage1_model,
75
- torch_dtype=torch.bfloat16,
76
- attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
77
- )
78
- # to device, if gpu is available
79
- model.to(device)
80
- model.eval()
81
-
82
- codectool = CodecManipulator("xcodec", 0, 1)
83
- codectool_stage2 = CodecManipulator("xcodec", 0, 8)
84
- model_config = OmegaConf.load(args.basic_model_config)
85
- codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
86
- parameter_dict = torch.load(args.resume_path, map_location='cpu')
87
- codec_model.load_state_dict(parameter_dict['codec_model'])
88
- codec_model.to(device)
89
- codec_model.eval()
90
-
91
- class BlockTokenRangeProcessor(LogitsProcessor):
92
- def __init__(self, start_id, end_id):
93
- self.blocked_token_ids = list(range(start_id, end_id))
94
-
95
- def __call__(self, input_ids, scores):
96
- scores[:, self.blocked_token_ids] = -float("inf")
97
- return scores
98
-
99
- def load_audio_mono(filepath, sampling_rate=16000):
100
- audio, sr = torchaudio.load(filepath)
101
- # Convert to mono
102
- audio = torch.mean(audio, dim=0, keepdim=True)
103
- # Resample if needed
104
- if sr != sampling_rate:
105
- resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
106
- audio = resampler(audio)
107
- return audio
108
-
109
- def split_lyrics(lyrics):
110
- pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
111
- segments = re.findall(pattern, lyrics, re.DOTALL)
112
- structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
113
- return structured_lyrics
114
-
115
- # Call the function and print the result
116
- stage1_output_set = []
117
- # Tips:
118
- # genre tags support instrumental,genre,mood,vocal timbr and vocal gender
119
- # all kinds of tags are needed
120
- with open(args.genre_txt) as f:
121
- genres = f.read().strip()
122
- with open(args.lyrics_txt) as f:
123
- lyrics = split_lyrics(f.read())
124
- # intruction
125
- full_lyrics = "\n".join(lyrics)
126
- prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
127
- prompt_texts += lyrics
128
-
129
-
130
- random_id = uuid.uuid4()
131
- output_seq = None
132
- # Here is suggested decoding config
133
- top_p = 0.93
134
- temperature = 1.0
135
- repetition_penalty = 1.2
136
- # special tokens
137
- start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
138
- end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
139
- # Format text prompt
140
- run_n_segments = min(args.run_n_segments+1, len(lyrics))
141
- for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
142
- section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
143
- guidance_scale = 1.5 if i <=1 else 1.2
144
- if i==0:
145
- continue
146
- if i==1:
147
- if args.use_audio_prompt:
148
- audio_prompt = load_audio_mono(args.audio_prompt_path)
149
- audio_prompt.unsqueeze_(0)
150
- with torch.no_grad():
151
- raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
152
- raw_codes = raw_codes.transpose(0, 1)
153
- raw_codes = raw_codes.cpu().numpy().astype(np.int16)
154
- # Format audio prompt
155
- code_ids = codectool.npy2ids(raw_codes[0])
156
- audio_prompt_codec = code_ids[int(args.prompt_start_time *50): int(args.prompt_end_time *50)] # 50 is tps of xcodec
157
- audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [mmtokenizer.eoa]
158
- sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize("[end_of_reference]")
159
- head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
160
- else:
161
- head_id = mmtokenizer.tokenize(prompt_texts[0])
162
- prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
163
- else:
164
- prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
165
-
166
- prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
167
- input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
168
- # Use window slicing in case output sequence exceeds the context of model
169
- max_context = 16384-max_new_tokens-1
170
- if input_ids.shape[-1] > max_context:
171
- print(f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
172
- input_ids = input_ids[:, -(max_context):]
173
- with torch.no_grad():
174
- output_seq = model.generate(
175
- input_ids=input_ids,
176
- max_new_tokens=max_new_tokens,
177
- min_new_tokens=100,
178
- do_sample=True,
179
- top_p=top_p,
180
- temperature=temperature,
181
- repetition_penalty=repetition_penalty,
182
- eos_token_id=mmtokenizer.eoa,
183
- pad_token_id=mmtokenizer.eoa,
184
- logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
185
- guidance_scale=guidance_scale,
186
- )
187
- if output_seq[0][-1].item() != mmtokenizer.eoa:
188
- tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
189
- output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
190
- if i > 1:
191
- raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
192
- else:
193
- raw_output = output_seq
194
-
195
- # save raw output and check sanity
196
- ids = raw_output[0].cpu().numpy()
197
- soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
198
- eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
199
- if len(soa_idx)!=len(eoa_idx):
200
- raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
201
-
202
- vocals = []
203
- instrumentals = []
204
- range_begin = 1 if args.use_audio_prompt else 0
205
- for i in range(range_begin, len(soa_idx)):
206
- codec_ids = ids[soa_idx[i]+1:eoa_idx[i]]
207
- if codec_ids[0] == 32016:
208
- codec_ids = codec_ids[1:]
209
- codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
210
- vocals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[0])
211
- vocals.append(vocals_ids)
212
- instrumentals_ids = codectool.ids2npy(rearrange(codec_ids,"(n b) -> b n", b=2)[1])
213
- instrumentals.append(instrumentals_ids)
214
- vocals = np.concatenate(vocals, axis=1)
215
- instrumentals = np.concatenate(instrumentals, axis=1)
216
- vocal_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_vocal_{random_id}".replace('.', '@')+'.npy')
217
- inst_save_path = os.path.join(stage1_output_dir, f"cot_{genres.replace(' ', '-')}_tp{top_p}_T{temperature}_rp{repetition_penalty}_maxtk{max_new_tokens}_instrumental_{random_id}".replace('.', '@')+'.npy')
218
- np.save(vocal_save_path, vocals)
219
- np.save(inst_save_path, instrumentals)
220
- stage1_output_set.append(vocal_save_path)
221
- stage1_output_set.append(inst_save_path)
222
-
223
-
224
- # offload model
225
- if not args.disable_offload_model:
226
- model.cpu()
227
- del model
228
- torch.cuda.empty_cache()
229
-
230
- print("Stage 2 inference...")
231
- model_stage2 = AutoModelForCausalLM.from_pretrained(
232
- stage2_model,
233
- torch_dtype=torch.float16,
234
- attn_implementation="flash_attention_2"
235
- )
236
- model_stage2.to(device)
237
- model_stage2.eval()
238
-
239
- def stage2_generate(model, prompt, batch_size=16):
240
- codec_ids = codectool.unflatten(prompt, n_quantizer=1)
241
- codec_ids = codectool.offset_tok_ids(
242
- codec_ids,
243
- global_offset=codectool.global_offset,
244
- codebook_size=codectool.codebook_size,
245
- num_codebooks=codectool.num_codebooks,
246
- ).astype(np.int32)
247
-
248
- # Prepare prompt_ids based on batch size or single input
249
- if batch_size > 1:
250
- codec_list = []
251
- for i in range(batch_size):
252
- idx_begin = i * 300
253
- idx_end = (i + 1) * 300
254
- codec_list.append(codec_ids[:, idx_begin:idx_end])
255
-
256
- codec_ids = np.concatenate(codec_list, axis=0)
257
- prompt_ids = np.concatenate(
258
- [
259
- np.tile([mmtokenizer.soa, mmtokenizer.stage_1], (batch_size, 1)),
260
- codec_ids,
261
- np.tile([mmtokenizer.stage_2], (batch_size, 1)),
262
- ],
263
- axis=1
264
- )
265
- else:
266
- prompt_ids = np.concatenate([
267
- np.array([mmtokenizer.soa, mmtokenizer.stage_1]),
268
- codec_ids.flatten(), # Flatten the 2D array to 1D
269
- np.array([mmtokenizer.stage_2])
270
- ]).astype(np.int32)
271
- prompt_ids = prompt_ids[np.newaxis, ...]
272
-
273
- codec_ids = torch.as_tensor(codec_ids).to(device)
274
- prompt_ids = torch.as_tensor(prompt_ids).to(device)
275
- len_prompt = prompt_ids.shape[-1]
276
-
277
- block_list = LogitsProcessorList([BlockTokenRangeProcessor(0, 46358), BlockTokenRangeProcessor(53526, mmtokenizer.vocab_size)])
278
-
279
- # Teacher forcing generate loop
280
- for frames_idx in range(codec_ids.shape[1]):
281
- cb0 = codec_ids[:, frames_idx:frames_idx+1]
282
- prompt_ids = torch.cat([prompt_ids, cb0], dim=1)
283
- input_ids = prompt_ids
284
-
285
- with torch.no_grad():
286
- stage2_output = model.generate(input_ids=input_ids,
287
- min_new_tokens=7,
288
- max_new_tokens=7,
289
- eos_token_id=mmtokenizer.eoa,
290
- pad_token_id=mmtokenizer.eoa,
291
- logits_processor=block_list,
292
- )
293
-
294
- assert stage2_output.shape[1] - prompt_ids.shape[1] == 7, f"output new tokens={stage2_output.shape[1]-prompt_ids.shape[1]}"
295
- prompt_ids = stage2_output
296
-
297
- # Return output based on batch size
298
- if batch_size > 1:
299
- output = prompt_ids.cpu().numpy()[:, len_prompt:]
300
- output_list = [output[i] for i in range(batch_size)]
301
- output = np.concatenate(output_list, axis=0)
302
- else:
303
- output = prompt_ids[0].cpu().numpy()[len_prompt:]
304
-
305
- return output
306
-
307
- def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=4):
308
- stage2_result = []
309
- for i in tqdm(range(len(stage1_output_set))):
310
- output_filename = os.path.join(stage2_output_dir, os.path.basename(stage1_output_set[i]))
311
-
312
- if os.path.exists(output_filename):
313
- print(f'{output_filename} stage2 has done.')
314
- continue
315
-
316
- # Load the prompt
317
- prompt = np.load(stage1_output_set[i]).astype(np.int32)
318
-
319
- # Only accept 6s segments
320
- output_duration = prompt.shape[-1] // 50 // 6 * 6
321
- num_batch = output_duration // 6
322
-
323
- if num_batch <= batch_size:
324
- # If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
325
- output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
326
- else:
327
- # If num_batch is greater than batch_size, process in chunks of batch_size
328
- segments = []
329
- num_segments = (num_batch // batch_size) + (1 if num_batch % batch_size != 0 else 0)
330
-
331
- for seg in range(num_segments):
332
- start_idx = seg * batch_size * 300
333
- # Ensure the end_idx does not exceed the available length
334
- end_idx = min((seg + 1) * batch_size * 300, output_duration*50) # Adjust the last segment
335
- current_batch_size = batch_size if seg != num_segments-1 or num_batch % batch_size == 0 else num_batch % batch_size
336
- segment = stage2_generate(
337
- model,
338
- prompt[:, start_idx:end_idx],
339
- batch_size=current_batch_size
340
- )
341
- segments.append(segment)
342
-
343
- # Concatenate all the segments
344
- output = np.concatenate(segments, axis=0)
345
-
346
- # Process the ending part of the prompt
347
- if output_duration*50 != prompt.shape[-1]:
348
- ending = stage2_generate(model, prompt[:, output_duration*50:], batch_size=1)
349
- output = np.concatenate([output, ending], axis=0)
350
- output = codectool_stage2.ids2npy(output)
351
-
352
- # Fix invalid codes (a dirty solution, which may harm the quality of audio)
353
- # We are trying to find better one
354
- fixed_output = copy.deepcopy(output)
355
- for i, line in enumerate(output):
356
- for j, element in enumerate(line):
357
- if element < 0 or element > 1023:
358
- counter = Counter(line)
359
- most_frequant = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
360
- fixed_output[i, j] = most_frequant
361
- # save output
362
- np.save(output_filename, fixed_output)
363
- stage2_result.append(output_filename)
364
- return stage2_result
365
-
366
- stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=args.stage2_batch_size)
367
- print(stage2_result)
368
- print('Stage 2 DONE.\n')
369
- # convert audio tokens to audio
370
- def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
371
- folder_path = os.path.dirname(path)
372
- if not os.path.exists(folder_path):
373
- os.makedirs(folder_path)
374
- limit = 0.99
375
- max_val = wav.abs().max()
376
- wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
377
- torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
378
- # reconstruct tracks
379
- recons_output_dir = os.path.join(args.output_dir, "recons")
380
- recons_mix_dir = os.path.join(recons_output_dir, 'mix')
381
- os.makedirs(recons_mix_dir, exist_ok=True)
382
- tracks = []
383
- for npy in stage2_result:
384
- codec_result = np.load(npy)
385
- decodec_rlt=[]
386
- with torch.no_grad():
387
- decoded_waveform = codec_model.decode(torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device))
388
- decoded_waveform = decoded_waveform.cpu().squeeze(0)
389
- decodec_rlt.append(torch.as_tensor(decoded_waveform))
390
- decodec_rlt = torch.cat(decodec_rlt, dim=-1)
391
- save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
392
- tracks.append(save_path)
393
- save_audio(decodec_rlt, save_path, 16000)
394
- # mix tracks
395
- for inst_path in tracks:
396
- try:
397
- if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
398
- and 'instrumental' in inst_path:
399
- # find pair
400
- vocal_path = inst_path.replace('instrumental', 'vocal')
401
- if not os.path.exists(vocal_path):
402
- continue
403
- # mix
404
- recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
405
- vocal_stem, sr = sf.read(inst_path)
406
- instrumental_stem, _ = sf.read(vocal_path)
407
- mix_stem = (vocal_stem + instrumental_stem) / 1
408
- sf.write(recons_mix, mix_stem, sr)
409
- except Exception as e:
410
- print(e)
411
-
412
- # vocoder to upsample audios
413
- vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)
414
- vocoder_output_dir = os.path.join(args.output_dir, 'vocoder')
415
- vocoder_stems_dir = os.path.join(vocoder_output_dir, 'stems')
416
- vocoder_mix_dir = os.path.join(vocoder_output_dir, 'mix')
417
- os.makedirs(vocoder_mix_dir, exist_ok=True)
418
- os.makedirs(vocoder_stems_dir, exist_ok=True)
419
- for npy in stage2_result:
420
- if 'instrumental' in npy:
421
- # Process instrumental
422
- instrumental_output = process_audio(
423
- npy,
424
- os.path.join(vocoder_stems_dir, 'instrumental.mp3'),
425
- args.rescale,
426
- args,
427
- inst_decoder,
428
- codec_model
429
- )
430
- else:
431
- # Process vocal
432
- vocal_output = process_audio(
433
- npy,
434
- os.path.join(vocoder_stems_dir, 'vocal.mp3'),
435
- args.rescale,
436
- args,
437
- vocal_decoder,
438
- codec_model
439
- )
440
- # mix tracks
441
- try:
442
- mix_output = instrumental_output + vocal_output
443
- vocoder_mix = os.path.join(vocoder_mix_dir, os.path.basename(recons_mix))
444
- save_audio(mix_output, vocoder_mix, 44100, args.rescale)
445
- print(f"Created mix: {vocoder_mix}")
446
- except RuntimeError as e:
447
- print(e)
448
- print(f"mix {vocoder_mix} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
449
-
450
- # Post process
451
- replace_low_freq_with_energy_matched(
452
- a_file=recons_mix, # 16kHz
453
- b_file=vocoder_mix, # 48kHz
454
- c_file=os.path.join(args.output_dir, os.path.basename(recons_mix)),
455
- cutoff_freq=5500.0
456
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference/mm_tokenizer_v0.2_hf/tokenizer.model DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ee5c7cbf32da93989f14d9ba635e3e1d1ab2cc88a92908a5ed0f149375f6ee49
3
- size 1761962
 
 
 
 
inference/mmtokenizer.py DELETED
@@ -1,367 +0,0 @@
1
- from abc import ABC
2
- from abc import abstractmethod
3
-
4
-
5
- class AbstractTokenizer(ABC):
6
- """Abstract class for tokenizer."""
7
-
8
- def __init__(self, name):
9
- self.name = name
10
- super().__init__()
11
-
12
- @property
13
- @abstractmethod
14
- def vocab_size(self):
15
- pass
16
-
17
- @property
18
- @abstractmethod
19
- def vocab(self):
20
- """Dictionary from vocab text token to id token."""
21
- pass
22
-
23
- @property
24
- @abstractmethod
25
- def inv_vocab(self):
26
- """Dictionary from vocab id token to text token."""
27
- pass
28
-
29
- @abstractmethod
30
- def tokenize(self, text):
31
- pass
32
-
33
- def detokenize(self, token_ids):
34
- raise NotImplementedError('detokenizer is not implemented for {} '
35
- 'tokenizer'.format(self.name))
36
-
37
- @property
38
- def cls(self):
39
- raise NotImplementedError('CLS is not provided for {} '
40
- 'tokenizer'.format(self.name))
41
-
42
- @property
43
- def sep(self):
44
- raise NotImplementedError('SEP is not provided for {} '
45
- 'tokenizer'.format(self.name))
46
-
47
- @property
48
- def pad(self):
49
- raise NotImplementedError('PAD is not provided for {} '
50
- 'tokenizer'.format(self.name))
51
-
52
- @property
53
- def eod(self):
54
- raise NotImplementedError('EOD is not provided for {} '
55
- 'tokenizer'.format(self.name))
56
-
57
- @property
58
- def mask(self):
59
- raise NotImplementedError('MASK is not provided for {} '
60
- 'tokenizer'.format(self.name))
61
-
62
-
63
- class _SentencePieceTokenizer(AbstractTokenizer):
64
- """SentencePieceTokenizer-Megatron wrapper"""
65
-
66
- def __init__(self, model_file, vocab_extra_ids=0):
67
- name = 'SentencePieceTokenizer'
68
- super().__init__(name)
69
-
70
- import sentencepiece
71
- self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
72
- self._initalize(vocab_extra_ids)
73
-
74
- def _populate_vocab(self):
75
- self._vocab = {}
76
- self._inv_vocab = {}
77
-
78
- for i in range(len(self.tokenizer)):
79
- t = self.tokenizer.id_to_piece(i)
80
- self._inv_vocab[i] = t
81
- self._vocab[t] = i
82
-
83
- def _initalize(self, vocab_extra_ids):
84
- self._populate_vocab()
85
- self._special_tokens = {}
86
- self._inv_special_tokens = {}
87
-
88
- self._t5_tokens = []
89
-
90
- def _add_special_token(t):
91
- if t not in self._vocab:
92
- next_id = len(self._vocab)
93
- self._vocab[t] = next_id
94
- self._inv_vocab[next_id] = t
95
- self._special_tokens[t] = self._vocab[t]
96
- self._inv_special_tokens[self._vocab[t]] = t
97
-
98
- _add_special_token('<CLS>')
99
- self._cls_id = self._vocab['<CLS>']
100
- _add_special_token('<SEP>')
101
- self._sep_id = self._vocab['<SEP>']
102
- _add_special_token('<EOD>')
103
- self._eod_id = self._vocab['<EOD>']
104
- _add_special_token('<MASK>')
105
- self._mask_id = self._vocab['<MASK>']
106
-
107
- pad_id = self.tokenizer.pad_id()
108
- try:
109
- pad_token = self.tokenizer.id_to_piece(pad_id)
110
- except IndexError:
111
- pad_token = '<PAD>'
112
- _add_special_token(pad_token)
113
- self._pad_id = self._vocab[pad_token]
114
-
115
- bos_id = self.tokenizer.bos_id()
116
- try:
117
- bos_token = self.tokenizer.id_to_piece(bos_id)
118
- except IndexError:
119
- bos_token = '<BOS>'
120
- _add_special_token(bos_token)
121
- self._bos_id = self._vocab[bos_token]
122
-
123
- eos_id = self.tokenizer.eos_id()
124
- try:
125
- eos_token = self.tokenizer.id_to_piece(eos_id)
126
- except IndexError:
127
- eos_token = '<EOS>'
128
- _add_special_token(eos_token)
129
- self._eos_id = self._vocab[eos_token]
130
-
131
- for i in range(vocab_extra_ids):
132
- t = "<extra_id_{}>".format(i)
133
- _add_special_token(t)
134
- self._t5_tokens += [t]
135
-
136
- @property
137
- def vocab_size(self):
138
- return len(self._vocab)
139
-
140
- @property
141
- def vocab(self):
142
- return self._vocab
143
-
144
- @property
145
- def inv_vocab(self):
146
- return self._inv_vocab
147
-
148
- @property
149
- def decoder(self):
150
- return self._inv_vocab
151
-
152
- @property
153
- def encoder(self):
154
- return self._vocab
155
-
156
- # From:
157
- # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
158
- def tokenize(self, text):
159
- ids = []
160
- idx = 0
161
-
162
- while 1:
163
- indices = {}
164
- for token in self._special_tokens:
165
- try:
166
- indices[token] = text[idx:].index(token)
167
- except ValueError:
168
- continue
169
- if len(indices) == 0:
170
- break
171
-
172
- next_token = min(indices, key=indices.get)
173
- next_idx = idx + indices[next_token]
174
-
175
- ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
176
- ids.append(self._special_tokens[next_token])
177
- idx = next_idx + len(next_token)
178
-
179
- ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
180
- return ids
181
-
182
- # From:
183
- # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125
184
- def detokenize(self, ids):
185
- text = ""
186
- last_i = 0
187
-
188
- for i, id in enumerate(ids):
189
- if id in self._inv_special_tokens:
190
- text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
191
- text += self._inv_special_tokens[id] + " "
192
- last_i = i + 1
193
-
194
- text += self.tokenizer.decode_ids(ids[last_i:])
195
- return text
196
-
197
- @property
198
- def cls(self):
199
- return self._cls_id
200
-
201
- @property
202
- def sep(self):
203
- return self._sep_id
204
-
205
- @property
206
- def pad(self):
207
- return self._pad_id
208
-
209
- @property
210
- def bos_token_id(self):
211
- return self._bos_id
212
-
213
- @property
214
- def bos(self):
215
- return self._bos_id
216
-
217
- @property
218
- def eod(self):
219
- return self._eod_id
220
-
221
- @property
222
- def eos_token_id(self):
223
- return self._eos_id
224
-
225
- @property
226
- def eos(self):
227
- return self._eos_id
228
-
229
- @property
230
- def mask(self):
231
- return self._mask_id
232
-
233
- @property
234
- def additional_special_tokens_ids(self):
235
- return [self.vocab[k] for k in self._t5_tokens]
236
-
237
- class _MMSentencePieceTokenizer(_SentencePieceTokenizer):
238
- """SentencePieceTokenizer-Megatron wrapper"""
239
-
240
- def __init__(self, model_file, vocab_extra_ids=0):
241
- super().__init__(model_file, vocab_extra_ids)
242
-
243
-
244
- def _initalize(self, vocab_extra_ids):
245
- self._populate_vocab()
246
- self._special_tokens = {}
247
- self._inv_special_tokens = {}
248
-
249
- self._t5_tokens = []
250
-
251
- def _add_special_token(t):
252
- if t not in self._vocab:
253
- next_id = len(self._vocab)
254
- self._vocab[t] = next_id
255
- self._inv_vocab[next_id] = t
256
- self._special_tokens[t] = self._vocab[t]
257
- self._inv_special_tokens[self._vocab[t]] = t
258
-
259
- _add_special_token('<CLS>')
260
- self._cls_id = self._vocab['<CLS>']
261
- _add_special_token('<SEP>')
262
- self._sep_id = self._vocab['<SEP>']
263
- _add_special_token('<EOD>')
264
- self._eod_id = self._vocab['<EOD>']
265
- _add_special_token('<MASK>')
266
- self._mask_id = self._vocab['<MASK>']
267
-
268
- _add_special_token('<SOA>')
269
- self._soa_id = self._vocab['<SOA>']
270
- _add_special_token('<EOA>')
271
- self._eoa_id = self._vocab['<EOA>']
272
- _add_special_token('<SOV>')
273
- self._sov_id = self._vocab['<SOV>']
274
- _add_special_token('<EOV>')
275
- self._eov_id = self._vocab['<EOV>']
276
- _add_special_token('<SOI>')
277
- self._soi_id = self._vocab['<SOI>']
278
- _add_special_token('<EOI>')
279
- self._eoi_id = self._vocab['<EOI>']
280
- _add_special_token('<s_local>')
281
- self._s_local_id = self._vocab['<s_local>']
282
- _add_special_token('<e_local>')
283
- self._e_local_id = self._vocab['<e_local>']
284
- _add_special_token('<s_global>')
285
- self._s_global_id = self._vocab['<s_global>']
286
- _add_special_token('<e_global>')
287
- self._e_global_id = self._vocab['<e_global>']
288
- _add_special_token('<stage_1>')
289
- self._stage_1_id = self._vocab['<stage_1>']
290
- _add_special_token('<stage_2>')
291
- self._stage_2_id = self._vocab['<stage_2>']
292
- pad_id = self.tokenizer.pad_id()
293
- try:
294
- pad_token = self.tokenizer.id_to_piece(pad_id)
295
- except IndexError:
296
- pad_token = '<PAD>'
297
- _add_special_token(pad_token)
298
- self._pad_id = self._vocab[pad_token]
299
-
300
- bos_id = self.tokenizer.bos_id()
301
- try:
302
- bos_token = self.tokenizer.id_to_piece(bos_id)
303
- except IndexError:
304
- bos_token = '<BOS>'
305
- _add_special_token(bos_token)
306
- self._bos_id = self._vocab[bos_token]
307
-
308
- eos_id = self.tokenizer.eos_id()
309
- try:
310
- eos_token = self.tokenizer.id_to_piece(eos_id)
311
- except IndexError:
312
- eos_token = '<EOS>'
313
- _add_special_token(eos_token)
314
- self._eos_id = self._vocab[eos_token]
315
-
316
- for i in range(vocab_extra_ids):
317
- t = "<extra_id_{}>".format(i)
318
- _add_special_token(t)
319
- self._t5_tokens += [t]
320
-
321
- @property
322
- def soa(self):
323
- return self._soa_id
324
-
325
- @property
326
- def eoa(self):
327
- return self._eoa_id
328
-
329
- @property
330
- def sov(self):
331
- return self._sov_id
332
-
333
- @property
334
- def eov(self):
335
- return self._eov_id
336
-
337
- @property
338
- def soi(self):
339
- return self._soi_id
340
-
341
- @property
342
- def eoi(self):
343
- return self._eoi_id
344
-
345
- @property
346
- def s_local(self):
347
- return self._s_local_id
348
-
349
- @property
350
- def e_local(self):
351
- return self._e_local_id
352
-
353
- @property
354
- def s_global(self):
355
- return self._s_global_id
356
-
357
- @property
358
- def e_global(self):
359
- return self._e_global_id
360
-
361
- @property
362
- def stage_1(self):
363
- return self._stage_1_id
364
-
365
- @property
366
- def stage_2(self):
367
- return self._stage_2_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference/prompt_examples/genre.txt DELETED
@@ -1 +0,0 @@
1
- inspiring female uplifting pop airy vocal electronic bright vocal vocal
 
 
inference/prompt_examples/lyrics.txt DELETED
@@ -1,39 +0,0 @@
1
- [verse]
2
- Staring at the sunset, colors paint the sky
3
- Thoughts of you keep swirling, can't deny
4
- I know I let you down, I made mistakes
5
- But I'm here to mend the heart I didn't break
6
-
7
- [chorus]
8
- Every road you take, I'll be one step behind
9
- Every dream you chase, I'm reaching for the light
10
- You can't fight this feeling now
11
- I won't back down
12
- You know you can't deny it now
13
- I won't back down
14
-
15
- [verse]
16
- They might say I'm foolish, chasing after you
17
- But they don't feel this love the way we do
18
- My heart beats only for you, can't you see?
19
- I won't let you slip away from me
20
-
21
- [chorus]
22
- Every road you take, I'll be one step behind
23
- Every dream you chase, I'm reaching for the light
24
- You can't fight this feeling now
25
- I won't back down
26
- You know you can't deny it now
27
- I won't back down
28
-
29
- [bridge]
30
- No, I won't back down, won't turn around
31
- Until you're back where you belong
32
- I'll cross the oceans wide, stand by your side
33
- Together we are strong
34
-
35
- [outro]
36
- Every road you take, I'll be one step behind
37
- Every dream you chase, love's the tie that binds
38
- You can't fight this feeling now
39
- I won't back down