joe123 committed
Commit d2902aa · 1 Parent(s): e51c0d5

Upload 2 files

Files changed (2)
  1. ctc.py +352 -0
  2. decode.py +257 -0
ctc.py ADDED
@@ -0,0 +1,352 @@
+ import yaml
+ import numpy as np
+ import copy
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from src.lm import RNNLM
+
+ LOG_ZERO = -10000000.0  # Log-zero for CTC
+
+ class CTCPrefixScore():
+     '''
+     CTC prefix score calculator.
+     An implementation of Algo. 2 in https://www.merl.com/publications/docs/TR2017-190.pdf (Watanabe et al.)
+     Reference (official implementation): https://github.com/espnet/espnet/tree/master/espnet/nets
+     '''
+
+     def __init__(self, x):
+         self.logzero = -100000000.0
+         self.blank = 0
+         self.eos = 1
+         self.x = x.cpu().numpy()[0]
+         self.odim = x.shape[-1]
+         self.input_length = len(self.x)
+
+     def init_state(self):
+         # 0 = non-blank, 1 = blank
+         r = np.full((self.input_length, 2), self.logzero, dtype=np.float32)
+
+         # Accumulate blank at each step
+         r[0, 1] = self.x[0, self.blank]
+         for i in range(1, self.input_length):
+             r[i, 1] = r[i-1, 1] + self.x[i, self.blank]
+         return r
+
+     def full_compute(self, g, r_prev):
+         '''Given prefix g, return the probability of every possible sequence y (where y = concat(g, c)).
+         This function computes all possible tokens for c (memory inefficient).'''
+         prefix_length = len(g)
+         last_char = g[-1] if prefix_length > 0 else 0
+
+         # init. r
+         r = np.full((self.input_length, 2, self.odim),
+                     self.logzero, dtype=np.float32)
+
+         # start from len(g) because it is impossible for CTC to generate |y| > |X|
+         start = max(1, prefix_length)
+
+         if prefix_length == 0:
+             r[0, 0, :] = self.x[0, :]  # if g = <sos>
+
+         psi = r[start-1, 0, :]
+
+         phi = np.logaddexp(r_prev[:, 0], r_prev[:, 1])
+
+         for t in range(start, self.input_length):
+             # prev_blank
+             prev_blank = np.full((self.odim), r_prev[t-1, 1], dtype=np.float32)
+             # prev_nonblank
+             prev_nonblank = np.full(
+                 (self.odim), r_prev[t-1, 0], dtype=np.float32)
+             prev_nonblank[last_char] = self.logzero
+
+             phi = np.logaddexp(prev_nonblank, prev_blank)
+             # P(h | current step is non-blank) = [P(prev. step = y) + phi] * P(c)
+             r[t, 0, :] = np.logaddexp(r[t-1, 0, :], phi) + self.x[t, :]
+             # P(h | current step is blank) = [P(prev. step is blank) + P(prev. step is non-blank)] * P(now = blank)
+             r[t, 1, :] = np.logaddexp(
+                 r[t-1, 1, :], r[t-1, 0, :]) + self.x[t, self.blank]
+             psi = np.logaddexp(psi, phi + self.x[t, :])
+
+         # psi[self.eos] = np.logaddexp(r_prev[-1, 0], r_prev[-1, 1])
+         return psi, np.rollaxis(r, 2)
+
+     def cheap_compute(self, g, r_prev, candidates):
+         '''Given prefix g, return the probability of every possible sequence y (where y = concat(g, c)).
+         This function considers only those tokens in candidates for c (memory efficient).'''
+         prefix_length = len(g)
+         odim = len(candidates)
+         last_char = g[-1] if prefix_length > 0 else 0
+
+         # init. r
+         r = np.full((self.input_length, 2, len(candidates)),
+                     self.logzero, dtype=np.float32)
+
+         # start from len(g) because it is impossible for CTC to generate |y| > |X|
+         start = max(1, prefix_length)
+
+         if prefix_length == 0:
+             r[0, 0, :] = self.x[0, candidates]  # if g = <sos>
+
+         psi = r[start-1, 0, :]
+         # phi = logsumexp of (prev_nonblank, prev_blank)
+         sum_prev = np.logaddexp(r_prev[:, 0], r_prev[:, 1])
+         phi = np.repeat(sum_prev[..., None], odim, axis=-1)
+         # Handle edge case: last token of the prefix is in candidates
+         if prefix_length > 0 and last_char in candidates:
+             phi[:, candidates.index(last_char)] = r_prev[:, 1]
+
+         for t in range(start, self.input_length):
+             # P(h | current step is non-blank) = [P(prev. step = y) + phi] * P(c)
+             r[t, 0, :] = np.logaddexp(r[t-1, 0, :], phi[t-1]) + self.x[t, candidates]
+             # P(h | current step is blank) = [P(prev. step is blank) + P(prev. step is non-blank)] * P(now = blank)
+             r[t, 1, :] = np.logaddexp(r[t-1, 1, :], r[t-1, 0, :]) + self.x[t, self.blank]
+             psi = np.logaddexp(psi, phi[t-1] + self.x[t, candidates])
+
+         # P(end of sentence) = P(g)
+         if self.eos in candidates:
+             psi[candidates.index(self.eos)] = sum_prev[-1]
+         return psi, np.rollaxis(r, 2)
+
+ class CTCHypothesis():
+     '''
+     Hypothesis for pure CTC beam search decoding.
+     An implementation of Algo. 1 in http://proceedings.mlr.press/v32/graves14.pdf
+     '''
+     def __init__(self):
+         self.y = []
+         # All probabilities are computed in log scale
+         self.Pr_y_t_blank = 0.0         # Pr-(y,t-1) -> Pr-(y,t)
+         self.Pr_y_t_nblank = LOG_ZERO   # Pr+(y,t-1) -> Pr+(y,t)
+
+         self.Pr_y_t_blank_bkup = 0.0        # Pr-(y,t-1) -> Pr-(y,t)
+         self.Pr_y_t_nblank_bkup = LOG_ZERO  # Pr+(y,t-1) -> Pr+(y,t)
+
+         self.lm_output = None
+         self.lm_hidden = None
+         self.updated_lm = False
+
+     def update_lm(self, output, hidden):
+         self.lm_output = output
+         self.lm_hidden = hidden
+         self.updated_lm = True
+
+     def get_len(self):
+         return len(self.y)
+
+     def get_string(self):
+         # Convert the output sequence from list to string
+         return ''.join([str(s) for s in self.y])
+
+     def get_score(self):
+         return np.logaddexp(self.Pr_y_t_blank, self.Pr_y_t_nblank)
+
+     def get_final_score(self):
+         # Length-normalized score used for the final ranking
+         if len(self.y) > 0:
+             return np.logaddexp(self.Pr_y_t_blank, self.Pr_y_t_nblank) / len(self.y)
+         else:
+             return np.logaddexp(self.Pr_y_t_blank, self.Pr_y_t_nblank)
+
+     def check_same(self, y_2):
+         if len(self.y) != len(y_2):
+             return False
+         for i in range(len(self.y)):
+             if self.y[i] != y_2[i]:
+                 return False
+         return True
+
+     def update_Pr_nblank(self, ctc_y_t):
+         # ctc_y_t : Pr(ye,t|x)
+         # Pr+(y,t) = Pr+(y,t-1) * Pr(ye,t|x)
+         self.Pr_y_t_nblank += ctc_y_t
+
+     def update_Pr_nblank_prefix(self, ctc_y_t, Pr_y_t_blank_prefix, Pr_y_t_nblank_prefix, Pr_ye_y=None):
+         # ctc_y_t : Pr(ye,t|x)
+         lm_prob = Pr_ye_y if Pr_ye_y is not None else 0.0
+         if len(self.y) == 0:
+             return
+         if len(self.y) == 1:
+             Pr_ye_y_prefix = ctc_y_t + lm_prob + np.logaddexp(Pr_y_t_blank_prefix, Pr_y_t_nblank_prefix)
+         else:
+             # Pr_ye_y : LM Pr(ye|y)
+             Pr_ye_y_prefix = ctc_y_t + lm_prob + (Pr_y_t_blank_prefix if self.y[-1] == self.y[-2]
+                                                   else np.logaddexp(Pr_y_t_blank_prefix, Pr_y_t_nblank_prefix))
+         # Pr+(y,t) = Pr+(y,t) + Pr(ye,y^,t)
+         self.Pr_y_t_nblank = np.logaddexp(self.Pr_y_t_nblank, Pr_ye_y_prefix)
+
+     def update_Pr_blank(self, ctc_blank_t):
+         # Pr-(y,t) = Pr(y,t-1) * Pr(-,t|x)
+         self.Pr_y_t_blank = np.logaddexp(self.Pr_y_t_nblank_bkup, self.Pr_y_t_blank_bkup) + ctc_blank_t
+
+     def add_token(self, token, ctc_token_t, Pr_k_y=None):
+         # Add token to the end of the sequence
+         # Update current sequence probability
+         lm_prob = Pr_k_y if Pr_k_y is not None else 0.0
+         if len(self.y) == 0:
+             Pr_y_t_nblank_new = ctc_token_t + lm_prob + np.logaddexp(self.Pr_y_t_blank_bkup, self.Pr_y_t_nblank_bkup)
+         else:
+             # Pr_k_y : LM Pr(k|y)
+             Pr_y_t_nblank_new = ctc_token_t + lm_prob + (self.Pr_y_t_blank_bkup if self.y[-1] == token
+                                                          else np.logaddexp(self.Pr_y_t_blank_bkup, self.Pr_y_t_nblank_bkup))
+
+         self.Pr_y_t_blank = LOG_ZERO
+         self.Pr_y_t_nblank = Pr_y_t_nblank_new
+
+         self.Pr_y_t_blank_bkup = self.Pr_y_t_blank
+         self.Pr_y_t_nblank_bkup = self.Pr_y_t_nblank
+
+         self.y.append(token)
+
+     def orig_backup(self):
+         self.Pr_y_t_blank_bkup = self.Pr_y_t_blank
+         self.Pr_y_t_nblank_bkup = self.Pr_y_t_nblank
+
+ class CTCBeamDecoder(nn.Module):
+     ''' Beam decoder for ASR (CTC only) '''
+     def __init__(self, asr, vocab_range, beam_size, vocab_candidate,
+                  lm_path='', lm_config='', lm_weight=0.0, device=None):
+         super().__init__()
+         # Setup
+         self.asr = asr
+         self.vocab_range = vocab_range
+         self.beam_size = beam_size
+         self.vocab_cand = vocab_candidate
+         assert self.vocab_cand <= len(self.vocab_range)
+
+         assert self.asr.enable_ctc
+
+         # Setup RNNLM
+         self.apply_lm = lm_weight > 0
+         self.lm_w = 0
+         if self.apply_lm:
+             self.device = device
+             self.lm_w = lm_weight
+             self.lm_path = lm_path
+             lm_config = yaml.load(open(lm_config, 'r'), Loader=yaml.FullLoader)
+             self.lm = RNNLM(self.asr.vocab_size, **lm_config['model']).to(self.device)
+             self.lm.load_state_dict(torch.load(
+                 self.lm_path, map_location='cpu')['model'])
+             self.lm.eval()
+
+     def create_msg(self):
+         msg = ['Decode spec| CTC decoding \t| Beam size = {} \t| LM weight = {}'.format(self.beam_size, self.lm_w)]
+         return msg
+
+     def forward(self, feat, feat_len):
+         # Init.
+         assert feat.shape[0] == 1, "Batch size == 1 is required for beam search"
+
+         # Calculate CTC output probability
+         ctc_output, encode_len, att_output, att_align, dec_state = \
+             self.asr(feat, feat_len, 10)
+         del encode_len, att_output, att_align, dec_state, feat_len
+         ctc_output = F.log_softmax(ctc_output[0], dim=-1).cpu().detach().numpy()
+         T = len(ctc_output)  # ctc_output = Pr(k,t|x) / dim: T x Vocab
+
+         # Best W probable sequences
+         B = [CTCHypothesis()]
+         if self.apply_lm:
+             # 0 == <sos> for RNNLM
+             output, hidden = \
+                 self.lm(torch.zeros((1, 1), dtype=torch.long).to(self.device),
+                         torch.ones(1, dtype=torch.long).to(self.device), None)
+             B[0].update_lm(
+                 output.log_softmax(dim=-1).squeeze().cpu().numpy(),
+                 hidden
+             )
+
+         start = True
+         for t in range(T):
+             # Greedily ignore pads at the beginning of the sequence
+             if np.argmax(ctc_output[t]) == 0 and start:
+                 continue
+             else:
+                 start = False
+             B_new = []
+             for i in range(len(B)):  # For y in B
+                 B_i_new = copy.deepcopy(B[i])
+                 if B_i_new.get_len() > 0:  # If y is not empty
+                     if B_i_new.y[-1] == 1:
+                         # <eos> = 1 (reached the end)
+                         B_new.append(B_i_new)
+                         continue
+                     B_i_new.update_Pr_nblank(ctc_output[t, B_i_new.y[-1]])
+                     # Find the same prefix
+                     for j in range(len(B)):
+                         if i != j and B[j].check_same(B_i_new.y[:-1]):
+                             lm_prob = 0.0
+                             if self.apply_lm:
+                                 lm_prob = self.lm_w * B[j].lm_output[B_i_new.y[-1]]
+                             B_i_new.update_Pr_nblank_prefix(ctc_output[t, B_i_new.y[-1]],
+                                                             B[j].Pr_y_t_blank, B[j].Pr_y_t_nblank, lm_prob)
+                             break
+                 B_i_new.update_Pr_blank(ctc_output[t, 0])  # 0 == <pad>
+                 if self.apply_lm:
+                     lm_hidden = B_i_new.lm_hidden
+                     lm_probs = B_i_new.lm_output
+                 else:
+                     lm_hidden = None
+                     lm_probs = None
+
+                 # Sort the next possible output symbols by CTC (and LM) score
+                 if self.apply_lm:
+                     ctc_vocab_cand = sorted(zip(
+                         self.vocab_range, ctc_output[t, self.vocab_range] + self.lm_w * lm_probs[self.vocab_range]),
+                         reverse=True, key=lambda x: x[1])
+                 else:
+                     ctc_vocab_cand = sorted(zip(self.vocab_range, ctc_output[t, self.vocab_range]),
+                                             reverse=True, key=lambda x: x[1])
+                 # Select top K possible symbols to calculate the probabilities
+                 for j in range(self.vocab_cand):
+                     # <pad>=0, <eos>=1, <unk>=2
+                     k = ctc_vocab_cand[j][0]
+                     # Pr(k,t|x)
+                     hyp_yk = copy.deepcopy(B_i_new)
+                     lm_prob = 0.0 if not self.apply_lm else self.lm_w * lm_probs[k]
+                     hyp_yk.add_token(k, ctc_output[t, k], lm_prob)
+                     hyp_yk.updated_lm = False
+                     B_new.append(hyp_yk)
+                 B_i_new.orig_backup()  # Back up original prob. (hypothesis kept without adding a token)
+                 B_new.append(B_i_new)
+             del B
+             B = []
+
+             # Remove duplicated sequences by sorting first (O(NlogN))
+             B_new = sorted(B_new, key=lambda x: x.get_string())
+             B.append(B_new[0])  # The first hyp. is always unique
+             for i in range(1, len(B_new)):
+                 if B_new[i].check_same(B[-1].y):
+                     # Next hyp. is a duplicate, keep the higher-scoring one
+                     if B_new[i].get_score() > B[-1].get_score():
+                         B[-1] = B_new[i]
+                     continue
+                 else:
+                     # Next hyp. is different, hence valid
+                     B.append(B_new[i])
+             del B_new
+
+             # Find the top W possible sequences
+             if t == T - 1:
+                 B = sorted(B, reverse=True, key=lambda x: x.get_final_score())
+             else:
+                 B = sorted(B, reverse=True, key=lambda x: x.get_score())
+             if len(B) > self.beam_size:
+                 B = B[:self.beam_size]
+
+             # Update LM states
+             if self.apply_lm and t < T - 1:
+                 for i in range(len(B)):
+                     if B[i].get_len() > 0 and not B[i].updated_lm:
+                         output, hidden = \
+                             self.lm(B[i].y[-1] * torch.ones((1, 1), dtype=torch.long).to(self.device),
+                                     torch.ones(1, dtype=torch.long).to(self.device), B[i].lm_hidden)
+                         B[i].update_lm(
+                             output.log_softmax(dim=-1).squeeze().cpu().numpy(),
+                             hidden
+                         )
+
+         return [b.y for b in B]
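For readers new to CTC prefix scoring, here is a minimal, self-contained sketch (not part of this commit) of how CTCPrefixScore is driven: initialize the blank-only state for the empty prefix, score a few candidate extensions, then feed the returned per-candidate state back in when extending the chosen prefix. The frame count, vocabulary size, and candidate list are made-up illustration values, and `src.ctc` is assumed importable from the repo root.

import torch
import torch.nn.functional as F
from src.ctc import CTCPrefixScore

T, V = 50, 31                                    # assumed: 50 encoder frames, 31-token vocab
x = F.log_softmax(torch.randn(1, T, V), dim=-1)  # dummy CTC log-posteriors, batch size 1

scorer = CTCPrefixScore(x)
r_empty = scorer.init_state()                    # state of the empty prefix (blanks only)
cands = [1, 3, 5, 7]                             # candidate extensions c (<eos> = 1)

# psi[i] = log-prob of all paths whose labelling starts with cands[i];
# r_new[i] is the matching (T, 2) state to reuse for that one-token prefix.
psi, r_new = scorer.cheap_compute([], r_empty, cands)

# Extend the prefix [3] (index 1 in cands) with the same candidate set.
psi2, r2 = scorer.cheap_compute([3], r_new[1], cands)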
decode.py ADDED
@@ -0,0 +1,257 @@
+ import yaml
+ import numpy as np
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from src.lm import RNNLM
+ from src.ctc import CTCPrefixScore, LOG_ZERO
+
+ CTC_BEAM_RATIO = 1.5  # DO NOT CHANGE THIS, MAY CAUSE OOM
+
+
+ class BeamDecoder(nn.Module):
+     ''' Beam decoder for ASR '''
+
+     def __init__(self, asr, emb_decoder, beam_size, min_len_ratio, max_len_ratio,
+                  lm_path='', lm_config='', lm_weight=0.0, ctc_weight=0.0):
+         super().__init__()
+         # Setup
+         self.beam_size = beam_size
+         self.min_len_ratio = min_len_ratio
+         self.max_len_ratio = max_len_ratio
+         self.asr = asr
+
+         # ToDo : implement pure CTC decode
+         assert self.asr.enable_att
+
+         # Additional decoding modules
+         self.apply_ctc = ctc_weight > 0
+         if self.apply_ctc:
+             assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder'
+             self.ctc_w = ctc_weight
+             self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size)
+
+         self.apply_lm = lm_weight > 0
+         if self.apply_lm:
+             self.lm_w = lm_weight
+             self.lm_path = lm_path
+             lm_config = yaml.load(open(lm_config, 'r'), Loader=yaml.FullLoader)
+             self.lm = RNNLM(self.asr.vocab_size, **lm_config['model'])
+             self.lm.load_state_dict(torch.load(
+                 self.lm_path, map_location='cpu')['model'])
+             self.lm.eval()
+
+         self.apply_emb = emb_decoder is not None
+         if self.apply_emb:
+             self.emb_decoder = emb_decoder
+
+     def create_msg(self):
+         msg = ['Decode spec| Beam size = {}\t| Min/Max len ratio = {}/{}'.format(
+             self.beam_size, self.min_len_ratio, self.max_len_ratio)]
+         if self.apply_ctc:
+             msg.append(
+                 ' |Joint CTC decoding enabled \t| weight = {:.2f}\t'.format(self.ctc_w))
+         if self.apply_lm:
+             msg.append(' |Joint LM decoding enabled \t| weight = {:.2f}\t| src = {}'.format(
+                 self.lm_w, self.lm_path))
+         if self.apply_emb:
+             msg.append(' |Joint Emb. decoding enabled \t| weight = {:.2f}'.format(
+                 self.emb_decoder.fuse_lambda.mean().cpu().item()))
+
+         return msg
+
+     def forward(self, audio_feature, feature_len):
+         # Init.
+         assert audio_feature.shape[0] == 1, "Batch size == 1 is required for beam search"
+         batch_size = audio_feature.shape[0]
+         device = audio_feature.device
+         dec_state = self.asr.decoder.init_state(
+             batch_size)  # Init zero states
+         self.asr.attention.reset_mem()  # Flush attention mem
+         # Max output len set w/ hyper param.
+         max_output_len = int(
+             np.ceil(feature_len.cpu().item() * self.max_len_ratio))
+         # Min output len set w/ hyper param.
+         min_output_len = int(
+             np.ceil(feature_len.cpu().item() * self.min_len_ratio))
+         # Store attention map if location-aware
+         store_att = self.asr.attention.mode == 'loc'
+         prev_token = torch.zeros(
+             (batch_size, 1), dtype=torch.long, device=device)  # Start w/ <sos>
+         # Cache of beam search
+         final_hypothesis, next_top_hypothesis = [], []
+         # In case CTC is disabled
+         ctc_state, ctc_prob, candidates, lm_state = None, None, None, None
+
+         # Encode
+         encode_feature, encode_len = self.asr.encoder(
+             audio_feature, feature_len)
+
+         # CTC decoding
+         if self.apply_ctc:
+             ctc_output = F.log_softmax(
+                 self.asr.ctc_layer(encode_feature), dim=-1)
+             ctc_prefix = CTCPrefixScore(ctc_output)
+             ctc_state = ctc_prefix.init_state()
+
+         # Start w/ empty hypothesis
+         prev_top_hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=[],
+                                           output_scores=[], lm_state=None, ctc_prob=0,
+                                           ctc_state=ctc_state, att_map=None)]
+         # Attention decoding
+         for t in range(max_output_len):
+             for hypothesis in prev_top_hypothesis:
+                 # Resume previous step
+                 prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = hypothesis.get_state(
+                     device)
+                 self.asr.set_state(prev_dec_state, prev_attn)
+
+                 # Normal ASR forward
+                 attn, context = self.asr.attention(
+                     self.asr.decoder.get_query(), encode_feature, encode_len)
+                 asr_prev_token = self.asr.pre_embed(prev_token)
+                 decoder_input = torch.cat([asr_prev_token, context], dim=-1)
+                 cur_prob, d_state = self.asr.decoder(decoder_input)
+
+                 # Embedding fusion (output shape 1 x V)
+                 if self.apply_emb:
+                     _, cur_prob = self.emb_decoder(d_state, cur_prob, return_loss=False)
+                 else:
+                     cur_prob = F.log_softmax(cur_prob, dim=-1)
+
+                 # Perform CTC prefix scoring on limited candidates (else OOM easily)
+                 if self.apply_ctc:
+                     # TODO : Check the performance drop for computing part of candidates only
+                     _, ctc_candidates = cur_prob.squeeze(0).topk(self.ctc_beam_size, dim=-1)
+                     candidates = ctc_candidates.cpu().tolist()
+                     ctc_prob, ctc_state = ctc_prefix.cheap_compute(
+                         hypothesis.outIndex, prev_ctc_state, candidates)
+                     # TODO : study why ctc_char (slightly) > 0 sometimes
+                     ctc_char = torch.FloatTensor(ctc_prob - hypothesis.ctc_prob).to(device)
+
+                     # Combine CTC score and attention score (HACK: score candidates only, block the rest)
+                     hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(LOG_ZERO)
+                     for idx, char in enumerate(candidates):
+                         hack_ctc_char[0, char] = ctc_char[idx]
+                     cur_prob = (1 - self.ctc_w) * cur_prob + self.ctc_w * hack_ctc_char
+                     cur_prob[0, 0] = LOG_ZERO  # Hack to ignore <sos>
+
+                 # Joint RNN-LM decoding
+                 if self.apply_lm:
+                     # assuming batch size is always 1, resulting in 1 x 1
+                     lm_input = prev_token.unsqueeze(1)
+                     lm_output, lm_state = self.lm(
+                         lm_input, torch.ones([batch_size]), hidden=prev_lm_state)
+                     # assuming batch size is always 1, resulting in 1 x V
+                     lm_output = lm_output.squeeze(0)
+                     cur_prob += self.lm_w * lm_output.log_softmax(dim=-1)
+
+                 # Beam search
+                 # Note: Ignored batch dim.
+                 topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
+                 prev_attn = self.asr.attention.att_layer.prev_att.cpu() if store_att else None
+                 final, top = hypothesis.addTopk(topi, topv, self.asr.decoder.get_state(), att_map=prev_attn,
+                                                 lm_state=lm_state, ctc_state=ctc_state, ctc_prob=ctc_prob,
+                                                 ctc_candidates=candidates)
+                 # Move complete hyps. out
+                 if final is not None and (t >= min_output_len):
+                     final_hypothesis.append(final)
+                     if self.beam_size == 1:
+                         return final_hypothesis
+                 next_top_hypothesis.extend(top)
+
+             # Sort for top N beams
+             next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
+             prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
+             next_top_hypothesis = []
+
+         # Rescore all hyps. (finished/unfinished)
+         final_hypothesis += prev_top_hypothesis
+         final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
+
+         return final_hypothesis[:self.beam_size]
+
+
+ class Hypothesis:
+     '''Hypothesis for beam search decoding.
+     Stores the label-sequence history & scores, plus the previous decoder state,
+     CTC state, CTC score, LM state, and attention map (if necessary).'''
+
+     def __init__(self, decoder_state, output_seq, output_scores, lm_state, ctc_state, ctc_prob, att_map):
+         assert len(output_seq) == len(output_scores)
+         # Attention decoder
+         self.decoder_state = decoder_state
+         self.att_map = att_map
+
+         # RNN language model
+         if type(lm_state) is tuple:
+             self.lm_state = (lm_state[0].cpu(),
+                              lm_state[1].cpu())  # LSTM state
+         elif lm_state is None:
+             self.lm_state = None  # Init state
+         else:
+             self.lm_state = lm_state.cpu()  # GRU state
+
+         # Previous outputs
+         self.output_seq = output_seq  # Prefix, list of token tensors
+         self.output_scores = output_scores  # Prefix scores, list of floats
+
+         # CTC decoding
+         self.ctc_state = ctc_state  # CTC prefix state (np array)
+         self.ctc_prob = ctc_prob  # CTC prefix log-prob
+
+     def avgScore(self):
+         '''Return the averaged log probability of the hypothesis'''
+         assert len(self.output_scores) != 0
+         return sum(self.output_scores) / len(self.output_scores)
+
+     def addTopk(self, topi, topv, decoder_state, att_map=None,
+                 lm_state=None, ctc_state=None, ctc_prob=0.0, ctc_candidates=[]):
+         '''Expand current hypothesis with a given beam size'''
+         new_hypothesis = []
+         term_score = None
+         ctc_s, ctc_p = None, None
+         beam_size = topi.shape[-1]
+
+         for i in range(beam_size):
+             # Detect <eos>
+             if topi[i].item() == 1:
+                 term_score = topv[i].cpu()
+                 continue
+
+             idxes = self.output_seq[:]  # pass by value
+             scores = self.output_scores[:]  # pass by value
+             idxes.append(topi[i].cpu())
+             scores.append(topv[i].cpu())
+             if ctc_state is not None:
+                 # ToDo: Handle out-of-candidate case.
+                 idx = ctc_candidates.index(topi[i].item())
+                 ctc_s = ctc_state[idx, :, :]
+                 ctc_p = ctc_prob[idx]
+             new_hypothesis.append(Hypothesis(decoder_state,
+                                              output_seq=idxes, output_scores=scores, lm_state=lm_state,
+                                              ctc_state=ctc_s, ctc_prob=ctc_p, att_map=att_map))
+         if term_score is not None:
+             self.output_seq.append(torch.tensor(1))
+             self.output_scores.append(term_score)
+             return self, new_hypothesis
+         return None, new_hypothesis
+
+     def get_state(self, device):
+         prev_token = self.output_seq[-1] if len(self.output_seq) != 0 else 0
+         prev_token = torch.LongTensor([prev_token]).to(device)
+         att_map = self.att_map.to(device) if self.att_map is not None else None
+         if type(self.lm_state) is tuple:
+             lm_state = (self.lm_state[0].to(device),
+                         self.lm_state[1].to(device))  # LSTM state
+         elif self.lm_state is None:
+             lm_state = None  # Init state
+         else:
+             lm_state = self.lm_state.to(device)  # GRU state
+         return prev_token, self.decoder_state, att_map, lm_state, self.ctc_state
+
+     @property
+     def outIndex(self):
+         return [i.item() for i in self.output_seq]
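To see how the pieces fit together, a hypothetical wiring sketch follows; it is not part of this commit. The trained `asr` model, the input tensors `feat` (1 x T x D) and `feat_len` (a one-element length tensor), and any tokenizer used to detokenize the result are assumed to exist elsewhere in the project.

import torch
from src.decode import BeamDecoder

# Attention-only joint decoding: LM and CTC rescoring are disabled here (weights = 0).
decoder = BeamDecoder(asr, emb_decoder=None, beam_size=5,
                      min_len_ratio=0.0, max_len_ratio=0.3,
                      lm_path='', lm_config='', lm_weight=0.0, ctc_weight=0.0)
print(decoder.create_msg())

with torch.no_grad():
    hyps = decoder(feat, feat_len)  # returns Hypothesis objects sorted by avgScore()

best = hyps[0]
print(best.avgScore(), best.outIndex)  # averaged log-prob and the token-id sequence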