from os.path import join
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer
import torch
import pandas as pd
import librosa
import gradio as gr
from gradio.components import Audio, Dropdown, Textbox

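# Gradio demo: transcribe speech into ARPA phonemes and a selected phonological
# attribute (+/-) using a Wav2Vec2 CTC model over phonological attribute tokens.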

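# Map the attribute names offered in the UI to their index in `groups` below.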
Attributes = {'Dental': 2,
              'Labial': 4,
              'Consonant': 15,
              'Vowel': 19,
              'Fricative': 21,
              'Nasal': 22,
              'Stop': 23,
              'Affricate': 25,
              'Voiced': 31,
              'Bilabial': 32,
              }

# Define attribute groups (positive/negative token pairs).
# Make sure that every phoneme is covered in each group.
g1 = ['p_alveolar','n_alveolar']
g2 = ['p_palatal','n_palatal']
g3 = ['p_dental','n_dental']
g4 = ['p_glottal','n_glottal']
g5 = ['p_labial','n_labial']
g6 = ['p_velar','n_velar']
g7 = ['p_anterior','n_anterior']
g8 = ['p_posterior','n_posterior']
g9 = ['p_retroflex','n_retroflex']
g10 = ['p_mid','n_mid']
g11 = ['p_high_v','n_high_v']
g12 = ['p_low','n_low']
g13 = ['p_front','n_front']
g14 = ['p_back','n_back']
g15 = ['p_central','n_central']
g16 = ['p_consonant','n_consonant']
g17 = ['p_sonorant','n_sonorant']
g18 = ['p_long','n_long']
g19 = ['p_short','n_short']
g20 = ['p_vowel','n_vowel']
g21 = ['p_semivowel','n_semivowel']
g22 = ['p_fricative','n_fricative']
g23 = ['p_nasal','n_nasal']
g24 = ['p_stop','n_stop']
g25 = ['p_approximant','n_approximant']
g26 = ['p_affricate','n_affricate']
g27 = ['p_liquid','n_liquid']
g28 = ['p_continuant','n_continuant']
g29 = ['p_monophthong','n_monophthong']
g30 = ['p_diphthong','n_diphthong']
g31 = ['p_round','n_round']
g32 = ['p_voiced','n_voiced']
g33 = ['p_bilabial','n_bilabial']
g34 = ['p_coronal','n_coronal']
g35 = ['p_dorsal','n_dorsal']
groups = [g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15,
          g16, g17, g18, g19, g20, g21, g22, g23, g24, g25, g26, g27, g28,
          g29, g30, g31, g32, g33, g34, g35]


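# Load the attribute-level acoustic model and its processor, plus a separate
# CTC tokenizer over the phoneme vocabulary for decoding phoneme predictions.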
model_dir = 'model/'
processor = Wav2Vec2Processor.from_pretrained(model_dir)
model = Wav2Vec2ForCTC.from_pretrained(model_dir)
tokenizer_phoneme = Wav2Vec2CTCTokenizer(join(model_dir,"phoneme_vocab.json"), pad_token="<pad>", word_delimiter_token="")

phoneme_list = list(tokenizer_phoneme.get_vocab().keys())
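# Binary phoneme-by-attribute table: entry (phoneme, attribute token) is 1 when
# the phoneme carries that attribute.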
p_att = pd.read_csv(join(model_dir,"phonological_attributes_v12.csv"),index_col=0)

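# For each group, build a phoneme -> attribute-token mapping: a phoneme maps to
# the group's positive token (p_*) if it carries the attribute, otherwise to
# the negative token (n_*).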
mappers = []
for g in groups:
    ph2att = {}
    for att in g:
        att_phs = p_att[p_att[att] == 1].index
        for ph in att_phs:
            ph2att[ph] = att
    mappers.append(ph2att)

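# 0/1 projection matrix of shape (phoneme vocab, attribute vocab): row p has a
# 1 in the column of each attribute token that phoneme p maps to (one per
# group), so multiplying attribute log-probabilities by this matrix sums the
# per-group log-probabilities of each phoneme.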
p2att = torch.zeros((tokenizer_phoneme.vocab_size, processor.tokenizer.vocab_size)).type(torch.FloatTensor)

for p in phoneme_list:
    # The pad token maps directly to the pad column; every other phoneme gets a
    # 1 in the column of its attribute token from each group.
    if p == processor.tokenizer.pad_token:
        p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.pad_token_id] = 1
        continue
    for mapper in mappers:
        p2att[tokenizer_phoneme.convert_tokens_to_ids(p), processor.tokenizer.convert_tokens_to_ids(mapper[p])] = 1


# For each group, sort its attribute token ids and build a dict from a
# within-group index (offset by 1; index 0 is the pad/blank) back to the
# original token id. This inverts the mapping used in training, since here we
# need to map predictions back to the original tokens.
group_ids = [sorted(processor.tokenizer.convert_tokens_to_ids(group)) for group in groups]
group_ids = [{i + 1: tok_id for i, tok_id in enumerate(g)} for g in group_ids]

def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Log-softmax that gives masked-out entries (effectively) zero probability."""
    if mask is not None:
        mask = mask.float()
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
        # results in nans when the whole vector is masked.  We need a very small value instead of a
        # zero in the mask for these cases.  log(1 + 1e-45) is still basically 0, so we can safely
        # just add 1e-45 before calling mask.log().  We use 1e-45 because 1e-46 is so small it
        # becomes 0 - this is just the smallest value we can actually use.
        vector = vector + (mask + 1e-45).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)

def getPhonemes(logits):
    # Decode phonemes from attribute logits: for each group, take a log-softmax
    # restricted to that group's tokens (plus the pad/blank at index 0), sum
    # the group-wise log-probabilities, and project them onto the phoneme
    # vocabulary with the p2att matrix.
    ngroups = len(group_ids)
    log_probs_all_masked = []
    for i in range(ngroups):
        mask = torch.zeros(logits.size()[2], dtype=torch.bool)
        mask[0] = True
        mask[list(group_ids[i].values())] = True
        mask.unsqueeze_(0).unsqueeze_(0)
        log_probs = masked_log_softmax(vector=logits, mask=mask, dim=-1).masked_fill(~mask, 0)
        log_probs_all_masked.append(log_probs)
    log_probs_cat = torch.stack(log_probs_all_masked, dim=0).sum(dim=0)
    log_probs_phoneme = torch.matmul(p2att, log_probs_cat.transpose(1, 2)).transpose(1, 2).type(torch.FloatTensor)
    pred_ids = torch.argmax(log_probs_phoneme, dim=-1)
    pred = tokenizer_phoneme.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
    return pred

def getAtt(logits, i):
    # Restrict the logits to group i (plus the pad/blank at index 0), take the
    # argmax within the group, and map the within-group indices back to the
    # original token ids before decoding.
    mask = torch.zeros(logits.size()[2], dtype=torch.bool)
    mask[0] = True
    mask[list(group_ids[i].values())] = True
    logits_g = logits[:, :, mask]
    pred_ids = torch.argmax(logits_g, dim=-1)
    pred_ids = pred_ids.cpu().apply_(lambda x: group_ids[i].get(x, x))
    pred = processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
    # Render positive/negative attribute tokens as +attribute / -attribute.
    return pred.replace('p_', '+').replace('n_', '-')

def recognizeAudio(audio, mic_audioFilePath, att):
    i = Attributes[att]
    # Prefer the microphone recording when one is provided.
    audio = mic_audioFilePath if mic_audioFilePath else audio
    # Resample to the 16 kHz rate expected by Wav2Vec2.
    y, sr = librosa.load(audio, sr=16000)
    input_values = processor(audio=y, sampling_rate=sr, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    return getPhonemes(logits), getAtt(logits, i)


gui = gr.Interface(fn=recognizeAudio,
                   inputs=[Audio(label="Upload Audio File", type="filepath"),
                           Audio(source="microphone", type="filepath", label="Record from microphone"),
                           Dropdown(choices=list(Attributes.keys()), type="value", label="Select Attribute")],
                   outputs=[Textbox(label="ARPA Phoneme"), Textbox(label="Attribute (+/-)")])

gui.launch()