from os.path import join
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
import torch
import pandas as pd
import librosa
import gradio as gr
from gradio.components import Audio, Dropdown, Textbox
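# Map each UI attribute name to the zero-based index of its group in `groups` below,
# e.g. 'Dental' -> groups[2] == g3 == ['p_dental', 'n_dental'].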
Attributes = {'Dental': 2,
              'Labial': 4,
              'Consonant': 15,
              'Vowel': 19,
              'Fricative': 21,
              'Nasal': 22,
              'Stop': 23,
              'Affricate': 25,
              'Voiced': 31,
              'Bilabial': 32,
              }
# Define the attribute groups: each group pairs the positive (p_) and negative (n_) token
# for one phonological attribute.
# Make sure that all phonemes are covered in each group.
g1 = ['p_alveolar','n_alveolar']
g2 = ['p_palatal','n_palatal']
g3 = ['p_dental','n_dental']
g4 = ['p_glottal','n_glottal']
g5 = ['p_labial','n_labial']
g6 = ['p_velar','n_velar']
g7 = ['p_anterior','n_anterior']
g8 = ['p_posterior','n_posterior']
g9 = ['p_retroflex','n_retroflex']
g10 = ['p_mid','n_mid']
g11 = ['p_high_v','n_high_v']
g12 = ['p_low','n_low']
g13 = ['p_front','n_front']
g14 = ['p_back','n_back']
g15 = ['p_central','n_central']
g16 = ['p_consonant','n_consonant']
g17 = ['p_sonorant','n_sonorant']
g18 = ['p_long','n_long']
g19 = ['p_short','n_short']
g20 = ['p_vowel','n_vowel']
g21 = ['p_semivowel','n_semivowel']
g22 = ['p_fricative','n_fricative']
g23 = ['p_nasal','n_nasal']
g24 = ['p_stop','n_stop']
g25 = ['p_approximant','n_approximant']
g26 = ['p_affricate','n_affricate']
g27 = ['p_liquid','n_liquid']
g28 = ['p_continuant','n_continuant']
g29 = ['p_monophthong','n_monophthong']
g30 = ['p_diphthong','n_diphthong']
g31 = ['p_round','n_round']
g32 = ['p_voiced','n_voiced']
g33 = ['p_bilabial','n_bilabial']
g34 = ['p_coronal','n_coronal']
g35 = ['p_dorsal','n_dorsal']
groups = [g1, g2, g3, g4, g5, g6, g7, g8, g9, g10,
          g11, g12, g13, g14, g15, g16, g17, g18, g19, g20,
          g21, g22, g23, g24, g25, g26, g27, g28, g29, g30,
          g31, g32, g33, g34, g35]
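# Load the CTC model and processor (attribute-token vocabulary), a tokenizer over the
# phoneme vocabulary, and the phoneme-to-attribute lookup table.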
model_dir = 'model/'
processor = Wav2Vec2Processor.from_pretrained(model_dir)
model = Wav2Vec2ForCTC.from_pretrained(model_dir)
tokenizer_phoneme = Wav2Vec2CTCTokenizer(join(model_dir,"phoneme_vocab.json"), pad_token="<pad>", word_delimiter_token="")
phoneme_list = list(tokenizer_phoneme.get_vocab().keys())
p_att = pd.read_csv(join(model_dir,"phonological_attributes_v12.csv"),index_col=0)
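# For each group, build a phoneme -> attribute-token mapper from the lookup table:
# a phoneme maps to whichever of the group's p_/n_ tokens has a 1 in its column.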
mappers = []
for g in groups:
    ph2att = {}
    for att in g:
        att_phs = p_att[p_att[att] == 1].index
        for ph in att_phs:
            ph2att[ph] = att
    mappers.append(ph2att)
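# 0/1 projection matrix (phoneme vocab x attribute vocab): row p has a 1 in the column of
# every attribute token that phoneme carries; the pad token maps to the attribute pad token.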
p2att = torch.zeros((tokenizer_phoneme.vocab_size, processor.tokenizer.vocab_size)).type(torch.FloatTensor)
for p in phoneme_list:
    for mapper in mappers:
        if p == processor.tokenizer.pad_token:
            p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.pad_token_id] = 1
        else:
            p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.convert_tokens_to_ids(mapper[p])] = 1
group_ids = [sorted(processor.tokenizer.convert_tokens_to_ids(group)) for group in groups]
# The inverse of the mapping used in training: here we need to map each prediction
# (a position within a masked group) back to the original token id for decoding.
group_ids = [{i + 1: tok_id for i, tok_id in enumerate(g)} for g in group_ids]
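# Sketch of the inversion (hypothetical token ids): if a group's sorted ids were [7, 12],
# its entry becomes {1: 7, 2: 12}, mapping a position within the masked group back to the
# original vocabulary id, with index 0 left free for the blank/pad.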
def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    if mask is not None:
        mask = mask.float()
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
        # results in nans when the whole vector is masked. We need a very small value instead of a
        # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely
        # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it
        # becomes 0 - this is just the smallest value we can actually use.
        vector = vector + (mask + 1e-45).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)
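# Minimal usage sketch (illustrative shapes and values, not part of the app):
#   mask = torch.tensor([True, False, True, False, False]).unsqueeze(0).unsqueeze(0)  # (1, 1, V)
#   out = masked_log_softmax(torch.randn(1, 3, 5), mask)
# Masked classes receive log(1e-45) ~ -103.6 before the softmax, so they end up with
# effectively zero probability mass.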
def getPhonemes(logits):
    ngroups = len(group_ids)
    log_probs_all_masked = []
    for i in range(ngroups):
        # Restrict each group's softmax to the blank (index 0) plus that group's tokens.
        mask = torch.zeros(logits.size()[2], dtype=torch.bool)
        mask[0] = True
        mask[list(group_ids[i].values())] = True
        mask.unsqueeze_(0).unsqueeze_(0)
        log_probs = masked_log_softmax(vector=logits, mask=mask, dim=-1).masked_fill(~mask, 0)
        log_probs_all_masked.append(log_probs)
    # Sum the per-group log-probabilities, then project them onto the phoneme vocabulary.
    log_probs_cat = torch.stack(log_probs_all_masked, dim=0).sum(dim=0)
    log_probs_phoneme = torch.matmul(p2att, log_probs_cat.transpose(1, 2)).transpose(1, 2).type(torch.FloatTensor)
    pred_ids = torch.argmax(log_probs_phoneme, dim=-1)
    pred = tokenizer_phoneme.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
    return pred
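# Decode the +/- pattern of a single attribute group directly from the logits.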
def getAtt(logits, i):
    # Keep only the blank (index 0) and the tokens of the selected attribute group.
    mask = torch.zeros(logits.size()[2], dtype=torch.bool)
    mask[0] = True
    mask[list(group_ids[i].values())] = True
    logits_g = logits[:, :, mask]
    pred_ids = torch.argmax(logits_g, dim=-1)
    # Map positions within the masked group back to original vocabulary ids.
    pred_ids = pred_ids.cpu().apply_(lambda x: group_ids[i].get(x, x))
    pred = processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
    return pred.replace('p_', '+').replace('n_', '-')
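# Gradio callback: run the model on the uploaded or recorded audio and return both the
# phoneme string and the selected attribute's +/- pattern.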
def recognizeAudio(audio, mic_audioFilePath, att):
    i = Attributes[att]
    # Prefer the microphone recording when one is provided.
    audio = mic_audioFilePath if mic_audioFilePath else audio
    y, sr = librosa.load(audio, sr=16000)
    input_values = processor(audio=y, sampling_rate=sr, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    return getPhonemes(logits), getAtt(logits, i)
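# Gradio UI: file upload or microphone input plus an attribute selector; outputs the
# decoded ARPA phonemes and the chosen attribute pattern.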
gui = gr.Interface(fn=recognizeAudio,
                   inputs=[Audio(label="Upload Audio File", type="filepath"),
                           Audio(source="microphone", type="filepath", label="Record from microphone"),
                           Dropdown(choices=list(Attributes.keys()), type="value", label="Select Attribute")],
                   outputs=[Textbox(label="ARPA Phoneme"), Textbox(label="Attribute (+/-)")])
gui.launch()