# NOTE(review): the text originally here ("Spaces:" / "Running on Zero",
# repeated) was HuggingFace Spaces page-banner residue captured during
# extraction, not part of this module; preserved only as this comment.
import contextlib
import json
import math
import os

import natsort
import torch

from vita_audio.tokenizer import get_audio_tokenizer
class AudioProcessor:
    """Wrap an audio tokenizer to encode audio files into either discrete
    tokens or contiguous features, with a best-effort JSON cache for the
    discrete path."""

    def __init__(
        self,
        audio_tokenizer_path=None,
        audio_tokenizer_type=None,
        text_audio_interval_ratio=None,
    ):
        # get_audio_tokenizer may return None when no tokenizer is configured.
        self.audio_tokenizer = get_audio_tokenizer(
            audio_tokenizer_path,
            audio_tokenizer_type,
        )
        self.audio_tokenizer_type = audio_tokenizer_type
        # Interleaving schedule [text_n, audio_n, ...] consumed by
        # text_audio_interval(); None selects that function's default.
        self.text_audio_interval_ratio = text_audio_interval_ratio
        # NOTE: model loading is deferred; callers invoke load_model() explicitly.
        # self.load_model()

    def load_model(self):
        """Load the tokenizer's model weights, if a tokenizer is configured."""
        if self.audio_tokenizer is not None:
            self.audio_tokenizer.load_model()

    def process_audios(self, audio_path, is_discrete=False, is_contiguous=False, **kwargs):
        """Encode ``audio_path`` to discrete tokens or contiguous features.

        Exactly one of ``is_discrete`` / ``is_contiguous`` must be True.
        Discrete results are cached beside the audio file as
        ``<stem>_<tokenizer-flavor>.json``; cache read/write failures are
        deliberately ignored (best-effort caching, as in the original code).
        """
        assert not (is_discrete and is_contiguous)
        assert is_discrete or is_contiguous

        if is_discrete:
            # Cache name embeds the tokenizer flavor, e.g. "foo_glm4voice.json".
            tokenizer_flavor = self.audio_tokenizer_type.split("_")[-1]
            cache_path = os.path.splitext(audio_path)[0] + f"_{tokenizer_flavor}.json"
            # Best-effort cache read: any I/O or decode error falls through to
            # re-encoding (JSONDecodeError is a ValueError subclass).
            with contextlib.suppress(OSError, ValueError):
                if os.path.isfile(cache_path):
                    with open(cache_path, "r") as f:
                        return json.load(f)

        audio_data = self.audio_tokenizer.encode(
            audio_path, is_discrete=is_discrete, is_contiguous=is_contiguous, **kwargs
        )

        if is_discrete:
            # Best-effort cache write; only plain lists are cached since they
            # are directly JSON-serializable.
            with contextlib.suppress(OSError, TypeError):
                if isinstance(audio_data, list):
                    with open(cache_path, "w") as f:
                        json.dump(audio_data, f)
        return audio_data

    def is_discrete(self):
        """True when the underlying tokenizer emits discrete tokens."""
        return self.audio_tokenizer.is_discrete

    def is_contiguous(self):
        """True when the underlying tokenizer emits contiguous features."""
        return self.audio_tokenizer.is_contiguous

    def apply_to_role(self, role, **kwargs):
        """Delegate role-based formatting to the underlying tokenizer."""
        return self.audio_tokenizer.apply_to_role(role, **kwargs)

    def text_audio_interval(self, content_input_id, AUD_START_ID, AUD_END_ID):
        """Interleave text/audio tokens using this processor's schedule."""
        return text_audio_interval(
            content_input_id,
            AUD_START_ID,
            AUD_END_ID,
            self.text_audio_interval_ratio,
        )
def add_audio_input_contiguous(input_ids, audio_paths, tokenizer, audio_tokenizer):
    """Splice contiguous audio placeholders into a token sequence.

    Every AUD_TAG token in ``input_ids`` is expanded into
    AUD_START + N context placeholders + AUD_END, where N is the encoded
    audio feature length plus 4.  Returns ``(new_input_ids, audios,
    audio_indices)`` where each element of ``audio_indices`` is a
    (2, 1, N) tensor of (batch, sequence) positions for the placeholders;
    the batch row is zero-filled and rewritten later by collate_fn.
    """
    from ...constants import (
        AUD_START_TOKEN,
        AUD_END_TOKEN,
        AUD_TAG_TOKEN,
        AUD_CONTEXT_TOKEN,
    )

    def _single_id(token):
        # Each special token is assumed to map to exactly one id; take the first.
        return tokenizer(token, add_special_tokens=False).input_ids[0]

    AUD_CONTEXT_ID = _single_id(AUD_CONTEXT_TOKEN)
    AUD_TAG_ID = _single_id(AUD_TAG_TOKEN)
    AUD_START_ID = _single_id(AUD_START_TOKEN)
    AUD_END_ID = _single_id(AUD_END_TOKEN)

    tag_positions = [pos for pos, tok in enumerate(input_ids) if tok == AUD_TAG_ID]

    audios, audio_indices, merged_ids = [], [], []
    cursor = 0
    for tag_idx, tag_pos in enumerate(tag_positions):
        audio = audio_tokenizer.encode(audio_paths[tag_idx], is_contiguous=True)
        audios.append(audio)
        # Placeholder span: raw feature length plus 4 extra slots.
        span = audio.size(0) + 4

        merged_ids += input_ids[cursor:tag_pos]
        merged_ids += [AUD_START_ID]

        # Batch row will change in collate_fn; sequence row covers the
        # placeholder positions that are appended just below.
        batch_row = torch.zeros(1, span, dtype=torch.int64)
        seq_row = torch.arange(len(merged_ids), len(merged_ids) + span).unsqueeze(0)
        audio_indices.append(torch.stack([batch_row, seq_row], dim=0))

        merged_ids += [AUD_CONTEXT_ID] * span
        merged_ids += [AUD_END_ID]
        cursor = tag_pos + 1

    merged_ids += input_ids[cursor:]
    return merged_ids, audios, audio_indices
def text_audio_interval_old(input_ids, AUD_START_ID, AUD_END_ID, text_audio_interval_ratio):
    """Deprecated fixed-ratio interleaver; superseded by text_audio_interval().

    Extracts the single AUD_START..AUD_END audio span from ``input_ids`` and
    re-emits the sequence as alternating text/audio chunks, each audio chunk
    wrapped in AUD_START/AUD_END.  Pure-text and pure-audio sequences pass
    through unchanged.  (Dead ``if False:`` experiment branches and the
    superseded 13/26 defaults from the original were removed; behavior is
    identical.)
    """
    if text_audio_interval_ratio is not None:
        text_num, audio_num = text_audio_interval_ratio
    else:
        # Effective defaults (earlier experiments used 13 / 26).
        text_num = 4
        audio_num = 10

    # AUD_START and AUD_END themselves occupy two of the audio slots.
    audio_num = audio_num - 2

    starts = [i for i, x in enumerate(input_ids) if x == AUD_START_ID]
    ends = [i for i, x in enumerate(input_ids) if x == AUD_END_ID]

    # Text-only sequence: nothing to interleave.
    if len(starts) == 0 and len(ends) == 0:
        return input_ids

    assert len(starts) == 1
    assert len(ends) == 1
    st = starts[0]
    ed = ends[0]
    assert st < ed

    # Audio-only sequence: already one fully wrapped span.
    if st == 0 and ed == len(input_ids) - 1:
        return input_ids

    audio_tokens = input_ids[st + 1 : ed]
    text_tokens = input_ids[:st] + input_ids[ed + 1 :]

    # First audio chunk is full-size; first text chunk is a single token so
    # audio can start almost immediately.
    audio_tokens_chunks = [audio_tokens[:audio_num]] + [
        audio_tokens[i : i + audio_num]
        for i in range(audio_num, len(audio_tokens), audio_num)
    ]
    text_tokens_chunks = [text_tokens[:1]] + [
        text_tokens[i : i + text_num] for i in range(1, len(text_tokens), text_num)
    ]

    # Fold surplus chunks of the longer list into the final chunk so both
    # lists have the same length.
    chunk_num = min(len(audio_tokens_chunks), len(text_tokens_chunks))
    audio_tokens_chunks = audio_tokens_chunks[: chunk_num - 1] + [
        sum(audio_tokens_chunks[chunk_num - 1 :], [])
    ]
    text_tokens_chunks = text_tokens_chunks[: chunk_num - 1] + [
        sum(text_tokens_chunks[chunk_num - 1 :], [])
    ]

    interval_input_ids = []
    for text_chunk, audio_chunk in zip(text_tokens_chunks, audio_tokens_chunks):
        interval_input_ids += text_chunk + [AUD_START_ID] + audio_chunk + [AUD_END_ID]
    return interval_input_ids
def _chunk_by_schedule(tokens, sizes):
    """Split ``tokens`` into chunks whose lengths follow ``sizes``, repeating
    the last size once the schedule is exhausted.  Does not mutate ``sizes``."""
    chunks = []
    idx = 0
    pos = 0
    while pos < len(tokens):
        size = sizes[min(idx, len(sizes) - 1)]
        chunks.append(tokens[pos : pos + size])
        pos += size
        idx += 1
    return chunks


def text_audio_interval(input_ids, AUD_START_ID, AUD_END_ID, text_audio_interval_ratio):
    """Interleave text and audio tokens per a [text_n, audio_n, ...] schedule.

    ``input_ids`` must contain at most one AUD_START..AUD_END span.  The audio
    span and the surrounding text are re-chunked per the schedule (even entries
    are text budgets, odd entries audio budgets; the final entry of each repeats)
    and emitted alternately as ``text_chunk AUD_START audio_chunk AUD_END ...``.
    Pure-text and pure-audio sequences are returned unchanged.  (The original's
    dead successive default assignments were removed; only the last one,
    [1, 10, 4, 10], ever took effect.)
    """
    if text_audio_interval_ratio is None:
        # Default schedule: 1 text token then 10 audio, then 4 text / 10 audio
        # repeating.  (Earlier experiments: [13, 26] and [1, 4, 3, 8, 4, 10].)
        text_audio_interval_ratio = [1, 10, 4, 10]

    text_nums = text_audio_interval_ratio[::2]
    audio_nums = text_audio_interval_ratio[1::2]
    # AUD_START and AUD_END themselves occupy two of each audio budget.
    audio_nums = [x - 2 for x in audio_nums]

    starts = [i for i, x in enumerate(input_ids) if x == AUD_START_ID]
    ends = [i for i, x in enumerate(input_ids) if x == AUD_END_ID]

    # Text-only sequence: nothing to interleave.
    if len(starts) == 0 and len(ends) == 0:
        return input_ids

    assert len(starts) == 1
    assert len(ends) == 1
    st = starts[0]
    ed = ends[0]
    assert st < ed

    # Audio-only sequence: already one fully wrapped span.
    if st == 0 and ed == len(input_ids) - 1:
        return input_ids

    audio_tokens = input_ids[st + 1 : ed]
    text_tokens = input_ids[:st] + input_ids[ed + 1 :]

    audio_tokens_chunks = _chunk_by_schedule(audio_tokens, audio_nums)
    text_tokens_chunks = _chunk_by_schedule(text_tokens, text_nums)

    # Fold surplus chunks of the longer list into the final chunk so both
    # lists have the same length.
    chunk_num = min(len(audio_tokens_chunks), len(text_tokens_chunks))
    audio_tokens_chunks = audio_tokens_chunks[: chunk_num - 1] + [
        sum(audio_tokens_chunks[chunk_num - 1 :], [])
    ]
    text_tokens_chunks = text_tokens_chunks[: chunk_num - 1] + [
        sum(text_tokens_chunks[chunk_num - 1 :], [])
    ]

    interval_input_ids = []
    for text_chunk, audio_chunk in zip(text_tokens_chunks, audio_tokens_chunks):
        interval_input_ids += text_chunk + [AUD_START_ID] + audio_chunk + [AUD_END_ID]
    return interval_input_ids