# "Spaces: Sleeping" — Hugging Face Space status banner captured with the page
# during extraction; not part of the program.
# Standard library
from os.path import join

# Third-party
import gradio as gr
import librosa
import pandas as pd
import torch
from gradio.components import Audio, Dropdown, Textbox
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2ForCTC, Wav2Vec2Processor
# Map a user-facing attribute name to the index of its group in `groups`
# (defined below); the index selects which p_/n_ token pair getAtt() decodes.
Attributes = {
    'Dental': 2,
    'Labial': 4,
    'Consonant': 15,
    'Vowel': 19,
    'Fricative': 21,
    'Nasal': 22,
    'Stop': 23,
    'Affricate': 25,
    'Voiced': 31,
    'Bilabial': 32,
}
# Define the phonological attribute groups. Each attribute contributes a
# positive ('p_<attr>') and a negative ('n_<attr>') token; together the two
# tokens of each group must cover all phonemes.
# ORDER MATTERS: the indices stored in `Attributes` above refer to positions
# in this list, so do not reorder the names.
_ATTRIBUTE_NAMES = [
    'alveolar', 'palatal', 'dental', 'glottal', 'labial', 'velar',
    'anterior', 'posterior', 'retroflex', 'mid', 'high_v', 'low',
    'front', 'back', 'central', 'consonant', 'sonorant', 'long',
    'short', 'vowel', 'semivowel', 'fricative', 'nasal', 'stop',
    'approximant', 'affricate', 'liquid', 'continuant', 'monophthong',
    'diphthong', 'round', 'voiced', 'bilabial', 'coronal', 'dorsal',
]
groups = [[f'p_{name}', f'n_{name}'] for name in _ATTRIBUTE_NAMES]
# Load the acoustic model and its tokenizers from the local model directory.
model_dir = 'model/'
processor = Wav2Vec2Processor.from_pretrained(model_dir)  # attribute-token vocabulary
model = Wav2Vec2ForCTC.from_pretrained(model_dir)
# Separate CTC tokenizer over the phoneme vocabulary (ARPA symbols).
tokenizer_phoneme = Wav2Vec2CTCTokenizer(
    join(model_dir, "phoneme_vocab.json"),
    pad_token="<pad>",
    word_delimiter_token="",
)
phoneme_list = list(tokenizer_phoneme.get_vocab().keys())
# Phoneme-by-attribute indicator table: rows are phonemes, columns are
# p_/n_ attribute tokens; a 1 marks that the phoneme carries that token.
p_att = pd.read_csv(join(model_dir, "phonological_attributes_v12.csv"), index_col=0)
# For every group, build a dict mapping each phoneme to the single attribute
# token (p_<attr> or n_<attr>) it carries within that group.
mappers = []
for g in groups:
    # Renamed from `p2att` so it no longer shadows the module-level
    # phoneme-to-attribute tensor of the same name built below.
    phoneme_to_att = {}
    for att in g:
        for ph in p_att[p_att[att] == 1].index:
            phoneme_to_att[ph] = att
    mappers.append(phoneme_to_att)
# Binary pooling matrix (phoneme_vocab x attribute_vocab): entry [p, a] is 1
# when phoneme p carries attribute token a (pad maps only to pad). Used to
# pool attribute log-probabilities back into per-phoneme scores.
p2att = torch.zeros((tokenizer_phoneme.vocab_size, processor.tokenizer.vocab_size)).type(torch.FloatTensor)
for p in phoneme_list:
    if p == processor.tokenizer.pad_token:
        # Pad-token check hoisted out of the mapper loop: the original set
        # this same entry once per group; one assignment is enough.
        p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.pad_token_id] = 1
        continue
    for mapper in mappers:
        p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.convert_tokens_to_ids(mapper[p])] = 1
group_ids = [sorted(processor.tokenizer.convert_tokens_to_ids(group)) for group in groups]
# Map position-within-group (1-based; 0 stays reserved for the pad/blank token)
# back to the original token id. This is the inversion of the mapping used in
# training, as here we need to map predictions back to original tokens.
group_ids = [{pos + 1: tok_id for pos, tok_id in enumerate(g)} for g in group_ids]
def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Compute log_softmax of `vector` restricted to positions where `mask` is truthy.

    Masked-out positions receive a very large negative log-probability
    (effectively -inf) instead of being dropped, so the output keeps the
    input's shape. If `mask` is None, this is a plain log_softmax.

    Args:
        vector: logits tensor.
        mask: boolean/float tensor broadcastable to `vector` (leading dims are
            unsqueezed at dim 1 until the ranks match).
        dim: dimension along which to normalize.
    """
    if mask is not None:
        mask = mask.float()
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        # vector + mask.log() is an easy way to zero out masked elements in
        # log-space, but it results in NaNs when the whole vector is masked.
        # log(1 + 1e-45) is still basically 0, and 1e-45 is the smallest value
        # that doesn't underflow to 0, so adding it before the log is safe.
        vector = vector + (mask + 1e-45).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)
def getPhonemes(logits):
    """Decode attribute-level CTC logits into an ARPA phoneme string.

    For each attribute group, a softmax restricted to that group's two tokens
    (plus the blank at index 0) yields per-group log-probs; summing these
    across groups and pooling through the phoneme-to-attribute matrix `p2att`
    scores every phoneme, which is then greedily decoded.
    """
    ngroups = len(group_ids)
    log_props_all_masked = []
    for i in range(ngroups):
        mask = torch.zeros(logits.size()[2], dtype=torch.bool)
        mask[0] = True  # index 0 is the pad/blank token, always selectable
        mask[list(group_ids[i].values())] = True
        mask.unsqueeze_(0).unsqueeze_(0)  # (1, 1, vocab) to broadcast over (batch, time)
        # Zero out masked entries after the restricted log-softmax so the
        # sum over groups below only accumulates in-group scores.
        log_probs = masked_log_softmax(vector=logits, mask=mask, dim=-1).masked_fill(~mask, 0)
        log_props_all_masked.append(log_probs)
    log_probs_cat = torch.stack(log_props_all_masked, dim=0).sum(dim=0)
    # Pool attribute log-probs into phoneme log-probs:
    # (P, V) @ (B, V, T) -> (B, P, T) -> transpose -> (B, T, P)
    log_probs_phoneme = torch.matmul(p2att, log_probs_cat.transpose(1, 2)).transpose(1, 2).type(torch.FloatTensor)
    pred_ids = torch.argmax(log_probs_phoneme, dim=-1)
    pred = tokenizer_phoneme.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
    return pred
def getAtt(logits, i):
    """Decode logits for attribute group `i` into a '+attr -attr ...' string."""
    mask = torch.zeros(logits.size()[2], dtype=torch.bool)
    mask[0] = True  # blank/pad always selectable
    mask[list(group_ids[i].values())] = True
    logits_g = logits[:, :, mask]  # restrict vocabulary to this group (+ blank)
    pred_ids = torch.argmax(logits_g, dim=-1)
    # Positions in the restricted vocab -> original token ids
    # (position 0, the blank, maps to itself via the .get default).
    pred_ids = pred_ids.cpu().apply_(lambda x: group_ids[i].get(x, x))
    pred = processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
    return pred.replace('p_', '+').replace('n_', '-')
def recognizeAudio(audio, mic_audioFilePath, att):
    """Recognize phonemes and one attribute track from an audio file.

    Args:
        audio: path of an uploaded audio file (or None).
        mic_audioFilePath: path of a microphone recording (or None); takes
            precedence over `audio` when both are provided.
        att: attribute name, a key of `Attributes`.

    Returns:
        (phoneme_string, attribute_string) tuple.
    """
    i = Attributes[att]
    audio = mic_audioFilePath if mic_audioFilePath else audio
    if not audio:
        # Without this guard librosa.load(None) raises an opaque TypeError.
        raise gr.Error("Please upload an audio file or record one from the microphone.")
    y, sr = librosa.load(audio, sr=16000)  # model expects 16 kHz input
    input_values = processor(audio=y, sampling_rate=sr, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    return getPhonemes(logits), getAtt(logits, i)
# Build and launch the Gradio UI: two audio inputs (file upload and
# microphone), an attribute selector, and two text outputs.
# NOTE(review): gradio>=4 renamed Audio(source=...) to sources=[...] — confirm
# the pinned gradio version before upgrading.
gui = gr.Interface(
    fn=recognizeAudio,
    inputs=[
        Audio(label="Upload Audio File", type="filepath"),
        Audio(source="microphone", type="filepath", label="Record from microphone"),
        # Dropdown expects a concrete list of choices, not a dict view.
        Dropdown(choices=list(Attributes.keys()), type="value", label="Select Attribute"),
    ],
    outputs=[Textbox(label="ARPA Phoneme"), Textbox(label="Attribute (+/-)")],
)
gui.launch()