# vita_audio/data/processor/audio_processor.py
import json
import os

import torch

from vita_audio.tokenizer import get_audio_tokenizer
class AudioProcessor:
    """Encode audio files into discrete tokens or contiguous features via an audio tokenizer."""
def __init__(
self,
audio_tokenizer_path=None,
audio_tokenizer_type=None,
text_audio_interval_ratio=None,
):
self.audio_tokenizer = get_audio_tokenizer(
audio_tokenizer_path,
audio_tokenizer_type,
)
self.audio_tokenizer_type = audio_tokenizer_type
self.text_audio_interval_ratio = text_audio_interval_ratio
        # Model loading is deferred; call load_model() explicitly before encoding.
        # self.load_model()
def load_model(self):
if self.audio_tokenizer is not None:
self.audio_tokenizer.load_model()
def process_audios(self, audio_path, is_discrete=False, is_contiguous=False, **kwargs):
        # Exactly one of is_discrete / is_contiguous must be set.
        assert not (is_discrete and is_contiguous)
        assert is_discrete or is_contiguous
        if is_discrete:
            # Discrete tokens are cached next to the audio file as
            # "<audio_stem>_<tokenizer_suffix>.json".
            audio_tokenizer_type = self.audio_tokenizer_type.split("_")[-1]
            cache_path = os.path.splitext(audio_path)[0] + f"_{audio_tokenizer_type}.json"
try:
if os.path.isfile(cache_path):
with open(cache_path, "r") as f:
audio_data = json.load(f)
return audio_data
            except Exception:
                # A corrupt or unreadable cache file falls through to re-encoding.
                pass
audio_data = self.audio_tokenizer.encode(
audio_path, is_discrete=is_discrete, is_contiguous=is_contiguous, **kwargs
)
# print(f"{len(audio_data)=}")
if is_discrete:
try:
if isinstance(audio_data, list):
with open(cache_path, "w") as f:
json.dump(audio_data, f)
            except Exception:
                # Caching is best-effort; ignore write failures.
                pass
return audio_data
@property
def is_discrete(self):
return self.audio_tokenizer.is_discrete
@property
def is_contiguous(self):
return self.audio_tokenizer.is_contiguous
def apply_to_role(self, role, **kwargs):
return self.audio_tokenizer.apply_to_role(role, **kwargs)
def text_audio_interval(self, content_input_id, AUD_START_ID, AUD_END_ID):
return text_audio_interval(
content_input_id,
AUD_START_ID,
AUD_END_ID,
self.text_audio_interval_ratio,
)
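
# A minimal usage sketch (the tokenizer path and type below are illustrative,
# not shipped defaults):
#
#   processor = AudioProcessor(
#       audio_tokenizer_path="/path/to/audio_tokenizer",
#       audio_tokenizer_type="glm4voice",
#       text_audio_interval_ratio=[1, 10, 4, 10],
#   )
#   processor.load_model()
#   tokens = processor.process_audios("/path/to/utterance.wav", is_discrete=True)
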
def add_audio_input_contiguous(input_ids, audio_paths, tokenizer, audio_tokenizer):
from ...constants import (
AUD_START_TOKEN,
AUD_END_TOKEN,
AUD_TAG_TOKEN,
AUD_CONTEXT_TOKEN,
)
AUD_CONTEXT_ID = tokenizer(AUD_CONTEXT_TOKEN, add_special_tokens=False).input_ids
AUD_TAG_ID = tokenizer(AUD_TAG_TOKEN, add_special_tokens=False).input_ids
AUD_START_ID = tokenizer(AUD_START_TOKEN, add_special_tokens=False).input_ids
AUD_END_ID = tokenizer(AUD_END_TOKEN, add_special_tokens=False).input_ids
    # Each special token corresponds to a single id; unwrap the singleton lists.
    AUD_CONTEXT_ID = AUD_CONTEXT_ID[0]
    AUD_TAG_ID = AUD_TAG_ID[0]
    AUD_START_ID = AUD_START_ID[0]
    AUD_END_ID = AUD_END_ID[0]
    # Positions of the audio placeholder tags in the prompt.
    aud_positions = [i for i, x in enumerate(input_ids) if x == AUD_TAG_ID]
audios = []
audio_indices = []
new_input_ids = []
st = 0
for aud_idx, aud_pos in enumerate(aud_positions):
audio = audio_tokenizer.encode(audio_paths[aud_idx], is_contiguous=True)
audios.append(audio)
        # Number of AUD_CONTEXT placeholders: audio frame count plus 4 extra positions.
        audio_token_length = audio.size(0) + 4
new_input_ids += input_ids[st:aud_pos]
new_input_ids += [AUD_START_ID]
audio_indice_b = torch.zeros(
1, audio_token_length, dtype=torch.int64
) # This will change in collate_fn
        # Sequence positions of the AUD_CONTEXT placeholders, shape (1, audio_token_length).
        audio_indice_s = torch.arange(
            len(new_input_ids), len(new_input_ids) + audio_token_length
        ).unsqueeze(0)
        audio_indice_b_s = torch.stack(
            [audio_indice_b, audio_indice_s], dim=0
        )  # (2, num_audio, audio_length)
audio_indices.append(audio_indice_b_s)
new_input_ids += [AUD_CONTEXT_ID] * audio_token_length
new_input_ids += [AUD_END_ID]
st = aud_pos + 1
new_input_ids += input_ids[st:]
    return new_input_ids, audios, audio_indices
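
# Sketch of the returned layout for a single clip of N encoded frames
# (the concrete shapes follow from the code above):
#   new_input_ids : [..., AUD_START_ID, AUD_CONTEXT_ID x (N + 4), AUD_END_ID, ...]
#   audios        : one feature tensor per clip from audio_tokenizer.encode
#   audio_indices : one (2, 1, N + 4) tensor per clip stacking the (batch,
#                   sequence) positions of the AUD_CONTEXT placeholders; the
#                   batch row is rewritten later in collate_fn.
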
def text_audio_interval_old(input_ids, AUD_START_ID, AUD_END_ID, text_audio_interval_ratio):
    # Superseded by text_audio_interval below; kept for reference.
    if text_audio_interval_ratio is not None:
        text_num, audio_num = text_audio_interval_ratio
    else:
        # Default ratio (an earlier draft used 13 text / 26 audio).
        text_num = 4
        audio_num = 10
# exclude AUD_START and AUD_END
audio_num = audio_num - 2
st = [i for i, x in enumerate(input_ids) if x == AUD_START_ID]
ed = [i for i, x in enumerate(input_ids) if x == AUD_END_ID]
# only text
if len(st) == 0 and len(ed) == 0:
return input_ids
assert len(st) == 1
assert len(ed) == 1
st = st[0]
ed = ed[0]
assert st < ed
# only audio
if st == 0 and ed == len(input_ids) - 1:
return input_ids
audio_tokens = input_ids[st + 1 : ed]
text_tokens = input_ids[:st] + input_ids[ed + 1 :]
    # Two alternative chunking schemes were tried and left disabled: uniform
    # text_num/audio_num chunks, and warm-up chunks of [2, audio_num - 2] audio
    # and [1, text_num - 1] text. The active scheme emits one full audio chunk
    # and a single text token first, then uniform chunks.
    # audio: [0 .. audio_num-1] [audio_num .. 2*audio_num-1] ...
    audio_tokens_chunks = [audio_tokens[:audio_num]] + [
        audio_tokens[i : i + audio_num] for i in range(audio_num, len(audio_tokens), audio_num)
    ]
    # text: [0] [1 .. text_num] ...
    text_tokens_chunks = [text_tokens[:1]] + [
        text_tokens[i : i + text_num] for i in range(1, len(text_tokens), text_num)
    ]
chunk_num = min(len(audio_tokens_chunks), len(text_tokens_chunks))
audio_tokens_chunks = audio_tokens_chunks[: chunk_num - 1] + [
sum(audio_tokens_chunks[chunk_num - 1 :], [])
]
text_tokens_chunks = text_tokens_chunks[: chunk_num - 1] + [
sum(text_tokens_chunks[chunk_num - 1 :], [])
]
interval_input_ids = []
for text_tokens, audio_tokens in zip(text_tokens_chunks, audio_tokens_chunks):
interval_input_ids += text_tokens + [AUD_START_ID] + audio_tokens + [AUD_END_ID]
# interval_input_ids += text_tokens + audio_tokens
return interval_input_ids
def text_audio_interval(input_ids, AUD_START_ID, AUD_END_ID, text_audio_interval_ratio):
    # Re-interleave a "<text> <AUD_START> <audio> <AUD_END> <text>" sequence into
    # alternating text/audio chunks sized by text_audio_interval_ratio.
    if text_audio_interval_ratio is None:
        # Default interleaving ratio, read as alternating (text, audio) chunk
        # sizes: 1 text, 10 audio, then 4 text, 10 audio, with the last pair
        # repeating. Earlier candidates were [13, 26] (T A) and
        # [1, 4, 3, 8, 4, 10] (T A T A T A).
        text_audio_interval_ratio = [1, 10, 4, 10]
text_nums = text_audio_interval_ratio[::2]
audio_nums = text_audio_interval_ratio[1::2]
# exclude AUD_START and AUD_END
audio_nums = [x - 2 for x in audio_nums]
st = [i for i, x in enumerate(input_ids) if x == AUD_START_ID]
ed = [i for i, x in enumerate(input_ids) if x == AUD_END_ID]
# only text
if len(st) == 0 and len(ed) == 0:
return input_ids
assert len(st) == 1
assert len(ed) == 1
st = st[0]
ed = ed[0]
assert st < ed
# only audio
if st == 0 and ed == len(input_ids) - 1:
return input_ids
audio_tokens = input_ids[st + 1 : ed]
text_tokens = input_ids[:st] + input_ids[ed + 1 :]
    audio_tokens_chunks = []
    # Consume the configured audio chunk sizes in order; the last size repeats
    # for any remaining tokens.
    while len(audio_tokens) > 0:
        if len(audio_nums) > 1:
            audio_num = audio_nums.pop(0)
        else:
            audio_num = audio_nums[0]
        audio_tokens_chunks.append(audio_tokens[:audio_num])
        audio_tokens = audio_tokens[audio_num:]
    text_tokens_chunks = []
    # Same scheme for the text chunk sizes.
    while len(text_tokens) > 0:
        if len(text_nums) > 1:
            text_num = text_nums.pop(0)
        else:
            text_num = text_nums[0]
        text_tokens_chunks.append(text_tokens[:text_num])
        text_tokens = text_tokens[text_num:]
    # Keep the same number of text and audio chunks; fold any surplus chunks
    # into the final chunk of each stream.
    chunk_num = min(len(audio_tokens_chunks), len(text_tokens_chunks))
    audio_tokens_chunks = audio_tokens_chunks[: chunk_num - 1] + [
        sum(audio_tokens_chunks[chunk_num - 1 :], [])
    ]
    text_tokens_chunks = text_tokens_chunks[: chunk_num - 1] + [
        sum(text_tokens_chunks[chunk_num - 1 :], [])
    ]
interval_input_ids = []
for text_tokens, audio_tokens in zip(text_tokens_chunks, audio_tokens_chunks):
interval_input_ids += text_tokens + [AUD_START_ID] + audio_tokens + [AUD_END_ID]
# interval_input_ids += text_tokens + audio_tokens
return interval_input_ids
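

# A minimal, self-contained sketch of the interleaving above. The token ids
# (1xx for text, 2xx for audio) and the START/END ids are made up for
# illustration; real ids come from the text tokenizer.
if __name__ == "__main__":
    AUD_START_ID, AUD_END_ID = 1, 2
    text = [101, 102, 103, 104, 105]
    audio = list(range(201, 213))  # 12 discrete audio tokens
    flat = text[:3] + [AUD_START_ID] + audio + [AUD_END_ID] + text[3:]
    # With ratio [1, 10, 4, 10]: 1 text token, 8 audio tokens (10 minus the
    # START/END pair), 4 text tokens, then the remaining audio.
    print(text_audio_interval(flat, AUD_START_ID, AUD_END_ID, [1, 10, 4, 10]))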