Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2025 ByteDance and/or its affiliates. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
import torch.nn.functional as F | |
def map_phone_to_tokendict(item, pad_bos_eos=True): | |
# Merge Chinese phone and tone (Original dict ends at 173, i.e., ph_dict_size=173). 146~173 is punctuations. | |
phone = item['txt_token'].clone() | |
merged_phone = item['txt_token'].clone() | |
tone_tmp = item['tone'].clone() | |
# In tone_dict, tone_1 is 4, tone_2 is 11, tone_3 is 12, tone_4 is 13, tone_5 is 14, tone_6 is 15 | |
tone_tmp[tone_tmp==4] = 1 | |
tone_tmp[tone_tmp==11] = 2 | |
tone_tmp[tone_tmp==12] = 3 | |
tone_tmp[tone_tmp==13] = 4 | |
tone_tmp[tone_tmp==14] = 5 | |
tone_tmp[tone_tmp==15] = 6 | |
# Chinese phones lie in 3~100 in the phone_dict, we map them to 200~788 | |
ch_phone_idx = (phone >= 3) & (phone <= 100) | |
merged_phone[ch_phone_idx] = (merged_phone[ch_phone_idx] - 3) * 6 + 200 + tone_tmp[ch_phone_idx] | |
if pad_bos_eos: | |
merged_phone = F.pad(merged_phone, (1, 0), mode='constant', value=798) | |
merged_phone = F.pad(merged_phone, (0, 1), mode='constant', value=799) | |
return merged_phone | |
def split_ph_timestamp(ph_timestamp): | |
''' Input: ph_timestamp, shape [T] ''' | |
# Map the timestamp of each phone back to its original frame-level lengths | |
ph_timestamp[ph_timestamp >= 800] -= 800 | |
ph_list = [] | |
tone_list = [] | |
dur_list = [] | |
cur_timestamp = 0 | |
for idx, item in enumerate(ph_timestamp): | |
if idx % 2 == 0: | |
# Map Chinese phones back to its original phone_dict | |
if (200 <= item <= 788): | |
ph = (item - 200 - 1) // 6 + 3 | |
tone = (item - 200 - 1) % 6 + 1 | |
if tone == 1: | |
tone = 4 | |
else: | |
tone = tone + 9 | |
# Set English tone to '3' | |
else: | |
ph = item | |
tone = 3 | |
ph_list.append(ph) | |
tone_list.append(tone) | |
else: | |
dur_list.append((item - cur_timestamp)) | |
cur_timestamp = item | |
assert len(ph_list) == len(dur_list), f"{len(ph_list)}, {len(dur_list)}" | |
ph_seq, tone_seq, dur_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list), torch.LongTensor(dur_list) | |
return ph_seq, tone_seq, dur_seq, ph_timestamp[-1] | |
def split_ph(ph_seq): | |
''' Input: ph_timestamp, shape [T] ''' | |
ph_list = [] | |
tone_list = [] | |
for idx, item in enumerate(ph_seq): | |
# Map Chinese phones back to its original phone_dict | |
if (200 <= item <= 788): | |
ph = (item - 200 - 1) // 6 + 3 | |
tone = (item - 200 - 1) % 6 + 1 | |
if tone == 1: | |
tone = 4 | |
else: | |
tone = tone + 9 | |
# Set English tone to '3' | |
else: | |
ph = item | |
tone = 3 | |
ph_list.append(ph) | |
tone_list.append(tone) | |
assert len(ph_list) == len(tone_list) | |
ph_seq, tone_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list) | |
return ph_seq, tone_seq |