import yaml
import numpy as np
import copy
import torch
from torch import nn
import torch.nn.functional as F

from src.lm import RNNLM

LOG_ZERO = -10000000.0  # Log-zero for CTC
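
# Note: every probability in this module is kept in log scale, so products of
# probabilities become sums, and sums of probabilities go through np.logaddexp:
#   log(p1 * p2) = log p1 + log p2
#   log(p1 + p2) = np.logaddexp(log p1, log p2)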

class CTCPrefixScore():
    '''
    CTC prefix score calculator.
    An implementation of Algo. 2 in https://www.merl.com/publications/docs/TR2017-190.pdf (Watanabe et al.)
    Reference (official implementation): https://github.com/espnet/espnet/tree/master/espnet/nets
    '''

    def __init__(self, x):
        self.logzero = -100000000.0
        self.blank = 0
        self.eos = 1
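        # x is expected to hold CTC log-posteriors with shape (1, T, V)
        # (batch size 1); only the single batch entry is kept, as numpy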
        self.x = x.cpu().numpy()[0]
        self.odim = x.shape[-1]
        self.input_length = len(self.x)

    def init_state(self):
        # 0 = non-blank, 1 = blank
        r = np.full((self.input_length, 2), self.logzero, dtype=np.float32)

        # Accumulate blank probability at each step:
        # r[i, 1] = log P(frames 0..i are all blank), i.e. the empty prefix
        r[0, 1] = self.x[0, self.blank]
        for i in range(1, self.input_length):
            r[i, 1] = r[i-1, 1] + self.x[i, self.blank]
        return r

    def full_compute(self, g, r_prev):
        '''Given prefix g, return the probabilities of all possible sequences y (where y = concat(g, c)).
           This version scores every token in the vocabulary as c (memory inefficient).'''
        prefix_length = len(g)
        last_char = g[-1] if prefix_length > 0 else 0

        # init. r
        r = np.full((self.input_length, 2, self.odim),
                    self.logzero, dtype=np.float32)

        # start from len(g) because it is impossible for CTC to generate |y| > |X|
        start = max(1, prefix_length)

        if prefix_length == 0:
            r[0, 0, :] = self.x[0, :]    # if g = <sos>

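        # psi accumulates, per candidate c, the prefix score of y = g + [c]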
        psi = r[start-1, 0, :]

        phi = np.logaddexp(r_prev[:, 0], r_prev[:, 1])

        for t in range(start, self.input_length):
            # Prob. that frames up to t-1 already yield prefix g, split by
            # whether frame t-1 was blank or non-blank
            prev_blank = np.full((self.odim), r_prev[t-1, 1], dtype=np.float32)
            prev_nonblank = np.full(
                (self.odim), r_prev[t-1, 0], dtype=np.float32)
            # A repeated last char must be preceded by a blank, otherwise the
            # two occurrences would collapse into one
            prev_nonblank[last_char] = self.logzero

            phi = np.logaddexp(prev_nonblank, prev_blank)
            # P(y, frame t non-blank) = [P(y, frame t-1 non-blank) + phi] * P(c at t)
            r[t, 0, :] = np.logaddexp(r[t-1, 0, :], phi) + self.x[t, :]
            # P(y, frame t blank) = [P(y, frame t-1 blank) + P(y, frame t-1 non-blank)] * P(blank at t)
            r[t, 1, :] = np.logaddexp(
                r[t-1, 1, :], r[t-1, 0, :]) + self.x[t, self.blank]
            psi = np.logaddexp(psi, phi + self.x[t, :])

        #psi[self.eos] = np.logaddexp(r_prev[-1,0], r_prev[-1,1])
        return psi, np.rollaxis(r, 2)

    def cheap_compute(self, g, r_prev, candidates):
        '''Given prefix g, return the probabilities of all possible sequences y (where y = concat(g, c)).
           This version scores only the tokens in candidates as c (memory efficient).'''
        prefix_length = len(g)
        odim = len(candidates)
        last_char = g[-1] if prefix_length > 0 else 0

        # init. r
        r = np.full((self.input_length, 2, len(candidates)),
                    self.logzero, dtype=np.float32)

        # start from len(g) because it is impossible for CTC to generate |y| > |X|
        start = max(1, prefix_length)

        if prefix_length == 0:
            r[0, 0, :] = self.x[0, candidates]    # if g = <sos>

        psi = r[start-1, 0, :]
        # phi[t] = log P(frames up to t yield prefix g); candidate-independent,
        # so it is precomputed once and broadcast over all candidates
        sum_prev = np.logaddexp(r_prev[:, 0], r_prev[:, 1])
        phi = np.repeat(sum_prev[..., None], odim, axis=-1)
        # Edge case: if c repeats the last token of g, only the blank-ending
        # path of the prefix may be extended by c
        if prefix_length > 0 and last_char in candidates:
            phi[:, candidates.index(last_char)] = r_prev[:, 1]

        for t in range(start, self.input_length):
            # P(y, frame t non-blank) = [P(y, frame t-1 non-blank) + phi] * P(c at t)
            r[t, 0, :] = np.logaddexp(r[t-1, 0, :], phi[t-1]) + self.x[t, candidates]
            # P(y, frame t blank) = [P(y, frame t-1 blank) + P(y, frame t-1 non-blank)] * P(blank at t)
            r[t, 1, :] = np.logaddexp(r[t-1, 1, :], r[t-1, 0, :]) + self.x[t, self.blank]
            psi = np.logaddexp(psi, phi[t-1] + self.x[t, candidates])

        # P(end of sentence) = P(g)
        if self.eos in candidates:
            psi[candidates.index(self.eos)] = sum_prev[-1]
        return psi, np.rollaxis(r, 2)
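
# Minimal usage sketch for CTCPrefixScore (illustrative only; `ctc_logits` and
# the candidate ids below are made-up placeholders, not names defined here):
#
#   ctc_logp = F.log_softmax(ctc_logits, dim=-1)  # (1, T, V) CTC log-posteriors
#   scorer = CTCPrefixScore(ctc_logp)
#   cand = [3, 7, 9]                              # token ids to try as c
#   psi, r = scorer.cheap_compute([], scorer.init_state(), cand)
#   # psi[i] scores the prefix [cand[i]]; to extend that hypothesis, pass
#   # g=[cand[i]] and r_prev=r[i] (shape (T, 2)) on the next call.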

class CTCHypothesis():
    ''' 
        Hypothesis for pure CTC beam search decoding.
        An implementation of Algo. 1 in http://proceedings.mlr.press/v32/graves14.pdf
    '''
    def __init__(self):
        self.y = []
        # All probabilities are computed in log scale
        self.Pr_y_t_blank         = 0.0      # Pr-(y,t-1)  -> Pr-(y,t)
        self.Pr_y_t_nblank        = LOG_ZERO # Pr+(y,t-1)  -> Pr+(y,t)

        self.Pr_y_t_blank_bkup    = 0.0      # Pr-(y,t-1)  -> Pr-(y,t)
        self.Pr_y_t_nblank_bkup   = LOG_ZERO # Pr+(y,t-1)  -> Pr+(y,t)
        
        self.lm_output = None
        self.lm_hidden = None
        self.updated_lm = False
    
    def update_lm(self, output, hidden):
        self.lm_output = output
        self.lm_hidden = hidden
        self.updated_lm = True

    def get_len(self):
        return len(self.y)
    
    def get_string(self):
        # Convert the output sequence from list to string
        return ''.join([str(s) for s in self.y])
    
    def get_score(self):
        return np.logaddexp(self.Pr_y_t_blank, self.Pr_y_t_nblank)
    
    def get_final_score(self):
        if len(self.y) > 0:
            return np.logaddexp(self.Pr_y_t_blank, self.Pr_y_t_nblank) / len(self.y)
        else:
            return np.logaddexp(self.Pr_y_t_blank, self.Pr_y_t_nblank)

    def check_same(self, y_2):
        if len(self.y) != len(y_2):
            return False
        for i in range(len(self.y)):
            if self.y[i] != y_2[i]:
                return False
        return True

    def update_Pr_nblank(self, ctc_y_t):
        # ctc_y_t  : Pr(ye,t|x)
        # Pr+(y,t) = Pr+(y,t-1) * Pr(ye,t|x)
        self.Pr_y_t_nblank += ctc_y_t

    def update_Pr_nblank_prefix(self, ctc_y_t, Pr_y_t_blank_prefix, Pr_y_t_nblank_prefix, Pr_ye_y=None):
        # ctc_y_t  : Pr(ye,t|x)
        lm_prob = Pr_ye_y if Pr_ye_y is not None else 0.0
        if len(self.y) == 0: return
        if len(self.y) == 1:
            Pr_ye_y_prefix = ctc_y_t + lm_prob + np.logaddexp(Pr_y_t_blank_prefix, Pr_y_t_nblank_prefix)
        else:
            # Pr_ye_y : LM Pr(ye|y)
            Pr_ye_y_prefix = ctc_y_t + lm_prob + (Pr_y_t_blank_prefix if self.y[-1] == self.y[-2] \
                                        else np.logaddexp(Pr_y_t_blank_prefix, Pr_y_t_nblank_prefix))
        # Pr+(y,t) = Pr+(y,t) + Pr(ye,y^,t)
        self.Pr_y_t_nblank = np.logaddexp(self.Pr_y_t_nblank, Pr_ye_y_prefix)
    
    def update_Pr_blank(self, ctc_blank_t):
        # Pr-(y,t) = Pr(y,t-1) * Pr(-,t|x)
        self.Pr_y_t_blank = np.logaddexp(self.Pr_y_t_nblank_bkup, self.Pr_y_t_blank_bkup) + ctc_blank_t
    
    def add_token(self, token, ctc_token_t, Pr_k_y=None):
        # Add token to the end of the sequence
        # Update current sequence probability
        lm_prob = Pr_k_y if Pr_k_y is not None else 0.0
        if len(self.y) == 0:
            Pr_y_t_nblank_new = ctc_token_t + lm_prob + np.logaddexp(self.Pr_y_t_blank_bkup, self.Pr_y_t_nblank_bkup)
        else:
            # Pr_k_y : LM Pr(k|y)
            Pr_y_t_nblank_new = ctc_token_t + lm_prob + (self.Pr_y_t_blank_bkup if self.y[-1] == token else \
                                    np.logaddexp(self.Pr_y_t_blank_bkup, self.Pr_y_t_nblank_bkup))

        self.Pr_y_t_blank  = LOG_ZERO
        self.Pr_y_t_nblank = Pr_y_t_nblank_new

        self.Pr_y_t_blank_bkup  = self.Pr_y_t_blank
        self.Pr_y_t_nblank_bkup = self.Pr_y_t_nblank

        self.y.append(token)

    def orig_backup(self):
        self.Pr_y_t_blank_bkup  = self.Pr_y_t_blank
        self.Pr_y_t_nblank_bkup = self.Pr_y_t_nblank
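
    # Notation sketch (following Algo. 1 of Graves & Jaitly):
    #   Pr-(y, t): log prob. that the first t frames yield y and frame t is blank
    #   Pr+(y, t): log prob. that the first t frames yield y and frame t emits
    #              the last token of y
    #   get_score() = logaddexp(Pr-, Pr+) = log Pr(y | first t frames)
    # The *_bkup fields freeze the t-1 values so that update_Pr_blank() and
    # add_token() can still read them after the live values were overwritten.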

class CTCBeamDecoder(nn.Module):
    ''' Beam decoder for ASR (CTC only) '''
    def __init__(self, asr, vocab_range, beam_size, vocab_candidate,
            lm_path='', lm_config='', lm_weight=0.0, device=None):
        super().__init__()
        # Setup
        self.asr         = asr
        self.vocab_range = vocab_range
        self.beam_size   = beam_size
        self.vocab_cand  = vocab_candidate
        assert self.vocab_cand <= len(self.vocab_range)

        assert self.asr.enable_ctc

        # Setup RNNLM
        self.apply_lm = lm_weight > 0
        self.lm_w = 0
        if self.apply_lm:
            self.device = device
            self.lm_w = lm_weight
            self.lm_path = lm_path
            with open(lm_config, 'r') as f:
                lm_config = yaml.load(f, Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size, **lm_config['model']).to(self.device)
            self.lm.load_state_dict(torch.load(
                self.lm_path, map_location='cpu')['model'])
            self.lm.eval()

    def create_msg(self):
        msg = ['Decode spec| CTC decoding \t| Beam size = {} \t| LM weight = {}'.format(self.beam_size, self.lm_w)]
        return msg

    def forward(self, feat, feat_len):
        # Init.
        assert feat.shape[0] == 1, "Batch size == 1 is required for beam search"
        
        # Calculate CTC output probability
        ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.asr(feat, feat_len, 10)
        del encode_len, att_output, att_align, dec_state, feat_len
        ctc_output = F.log_softmax(ctc_output[0], dim=-1).cpu().detach().numpy()
        T = len(ctc_output) # ctc_output = Pr(k,t|x) / dim: T x Vocab

        # Best W probable sequences
        B = [CTCHypothesis()]
        if self.apply_lm:
            # 0 == <sos> for RNNLM
            output, hidden = \
                self.lm(torch.zeros((1,1),dtype=torch.long).to(self.device), torch.ones(1,dtype=torch.long).to(self.device), None)
            B[0].update_lm(
                (output).log_softmax(dim=-1).squeeze().cpu().numpy(),
                hidden
            )
        
        start = True
        for t in range(T):
            # Greedily skip blank/pad frames at the beginning of the sequence
            if np.argmax(ctc_output[t]) == 0 and start:
                continue
            else:
                start = False
            B_new = []
            for i in range(len(B)): # For y in B
                B_i_new = copy.deepcopy(B[i])
                if B_i_new.get_len() > 0: # If y is not empty
                    if B_i_new.y[-1] == 1:
                        # <eos> = 1 (reached the end)
                        B_new.append(B_i_new)
                        continue
                    B_i_new.update_Pr_nblank(ctc_output[t, B_i_new.y[-1]])
                    # Find the same prefix
                    for j in range(len(B)):
                        if i != j and B[j].check_same(B_i_new.y[:-1]):
                            lm_prob = 0.0
                            if self.apply_lm:
                                lm_prob = self.lm_w * B[j].lm_output[B_i_new.y[-1]]
                            B_i_new.update_Pr_nblank_prefix(ctc_output[t, B_i_new.y[-1]], 
                                B[j].Pr_y_t_blank, B[j].Pr_y_t_nblank, lm_prob)
                            break
                B_i_new.update_Pr_blank(ctc_output[t, 0]) # 0 == <pad>
                if self.apply_lm:
                    lm_hidden = B_i_new.lm_hidden
                    lm_probs = B_i_new.lm_output
                else:
                    lm_hidden = None
                    lm_probs = None
                
                # Sort the next possible output symbol by CTC (and LM) score
                if self.apply_lm:
                    ctc_vocab_cand = sorted(zip(
                        self.vocab_range, ctc_output[t, self.vocab_range] + self.lm_w * lm_probs[self.vocab_range]), 
                        reverse=True, key=lambda x: x[1])
                else:
                    ctc_vocab_cand = sorted(zip(self.vocab_range, ctc_output[t, self.vocab_range]), reverse=True, key=lambda x: x[1])
                # Select top K possible symbols to calculate the probabilities
                for j in range(self.vocab_cand):
                    # <pad>=0, <eos>=1, <unk>=2
                    k = ctc_vocab_cand[j][0]
                    # Pr(k,t|x)
                    hyp_yk = copy.deepcopy(B_i_new)
                    lm_prob = 0.0 if not self.apply_lm else self.lm_w * lm_probs[k]
                    hyp_yk.add_token(k, ctc_output[t, k], lm_prob)
                    hyp_yk.updated_lm = False
                    B_new.append(hyp_yk)
                B_i_new.orig_backup() # Backup current prob. as the t-1 values read at the next frame
                B_new.append(B_i_new)
            del B
            B = []

            # Remove duplicated sequences by sorting first (O(NlogN))
            B_new = sorted(B_new, key=lambda x: x.get_string())
            B.append(B_new[0]) # First Hyp always unique
            for i in range(1,len(B_new)):
                if B_new[i].check_same(B[-1].y):
                    # Next Hyp is duplicated, pick the higher one
                    if B_new[i].get_score() > B[-1].get_score():
                        B[-1] = B_new[i]
                    continue
                else:
                    # Next Hyp is different, hence valid
                    B.append(B_new[i])
            del B_new

            # Find top W possible sequences
            if t == T - 1:
                B = sorted(B, reverse=True, key=lambda x: x.get_final_score())
            else:
                B = sorted(B, reverse=True, key=lambda x: x.get_score())
            if len(B) > self.beam_size:
                B = B[:self.beam_size]
            
            # Update LM states
            if self.apply_lm and t < T - 1:
                for i in range(len(B)):
                    if B[i].get_len() > 0 and not B[i].updated_lm:
                        output, hidden = \
                            self.lm(B[i].y[-1] * torch.ones((1,1), dtype=torch.long).to(self.device), 
                                torch.ones(1,dtype=torch.long).to(self.device), B[i].lm_hidden)
                        B[i].update_lm(
                            (output).log_softmax(dim=-1).squeeze().cpu().numpy(),
                            hidden
                        )
        
        return [b.y for b in B]
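
# Minimal usage sketch for CTCBeamDecoder (illustrative; `asr_model`, `feat`
# and `feat_len` stand in for objects produced by the surrounding pipeline):
#
#   decoder = CTCBeamDecoder(asr_model,
#                            vocab_range=list(range(3, asr_model.vocab_size)),
#                            beam_size=5, vocab_candidate=5)
#   with torch.no_grad():
#       hyps = decoder(feat, feat_len)  # returns beam_size sequences, best first
#   best = hyps[0]                      # token-id list of the top hypothesis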