import copy

import numpy as np
import torch
import torch.nn.functional as F
import yaml
from torch import nn

from src.lm import RNNLM

# Log-domain stand-in for probability zero (exp(LOG_ZERO) underflows to 0).
LOG_ZERO = -10000000.0


class CTCPrefixScore:
    '''
    CTC prefix score calculator.
    An implementation of Algo. 2 in https://www.merl.com/publications/docs/TR2017-190.pdf (Watanabe et al.)
    Reference (official implementation): https://github.com/espnet/espnet/tree/master/espnet/nets
    '''

    def __init__(self, x):
        self.logzero = -100000000.0
        self.blank = 0    # index of the CTC blank token
        self.eos = 1      # index of the end-of-sentence token
        self.x = x.cpu().numpy()[0]    # CTC output of the single utterance in the batch, shape (T, odim)
        self.odim = x.shape[-1]
        self.input_length = len(self.x)

    def init_state(self):
        # r[t, 0] / r[t, 1]: log-probability of the empty prefix ending at
        # frame t with a non-blank / blank, respectively.
        r = np.full((self.input_length, 2), self.logzero, dtype=np.float32)
        # The empty prefix can only be produced by emitting blanks at every frame.
        r[0, 1] = self.x[0, self.blank]
        for i in range(1, self.input_length):
            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
        return r

    def full_compute(self, g, r_prev):
        '''Given prefix g, return the probabilities of all sequences y = concat(g, c),
        computed for every token c in the vocabulary (memory inefficient).'''
        prefix_length = len(g)
        last_char = g[-1] if prefix_length > 0 else 0

        # r[t, 0, c] / r[t, 1, c]: log-probability of the extended prefix g + c
        # ending at frame t with a non-blank / blank, for every token c.
        r = np.full((self.input_length, 2, self.odim),
                    self.logzero, dtype=np.float32)

        # At least prefix_length frames are needed to emit the prefix itself.
        start = max(1, prefix_length)

        if prefix_length == 0:
            r[0, 0, :] = self.x[0, :]

        psi = r[start - 1, 0, :]

        for t in range(start, self.input_length):
            # phi[c]: log-probability that g ends at frame t-1 in a state that
            # allows appending c; when c repeats the last character of g, the
            # non-blank ending is excluded (a blank must separate the repeat).
            prev_blank = np.full(self.odim, r_prev[t - 1, 1], dtype=np.float32)
            prev_nonblank = np.full(self.odim, r_prev[t - 1, 0], dtype=np.float32)
            prev_nonblank[last_char] = self.logzero
            phi = np.logaddexp(prev_nonblank, prev_blank)

            r[t, 0, :] = np.logaddexp(r[t - 1, 0, :], phi) + self.x[t, :]
            r[t, 1, :] = np.logaddexp(r[t - 1, 1, :], r[t - 1, 0, :]) + self.x[t, self.blank]
            psi = np.logaddexp(psi, phi + self.x[t, :])

        return psi, np.rollaxis(r, 2)

    def cheap_compute(self, g, r_prev, candidates):
        '''Given prefix g, return the probabilities of all sequences y = concat(g, c),
        considering only tokens c in candidates (memory efficient).'''
        prefix_length = len(g)
        odim = len(candidates)
        last_char = g[-1] if prefix_length > 0 else 0

        # r[t, 0, c] / r[t, 1, c]: log-probability of the extended prefix g + c
        # ending at frame t with a non-blank / blank, for each candidate c.
        r = np.full((self.input_length, 2, odim),
                    self.logzero, dtype=np.float32)

        # At least prefix_length frames are needed to emit the prefix itself.
        start = max(1, prefix_length)

        if prefix_length == 0:
            r[0, 0, :] = self.x[0, candidates]

        psi = r[start - 1, 0, :]

        # phi[t, c]: log-probability that g ends at frame t in a state that
        # allows appending c; if c repeats the last character of g, only the
        # blank-ending path counts.
        sum_prev = np.logaddexp(r_prev[:, 0], r_prev[:, 1])
        phi = np.repeat(sum_prev[..., None], odim, axis=-1)
        if prefix_length > 0 and last_char in candidates:
            phi[:, candidates.index(last_char)] = r_prev[:, 1]

        for t in range(start, self.input_length):
            r[t, 0, :] = np.logaddexp(r[t - 1, 0, :], phi[t - 1]) + self.x[t, candidates]
            r[t, 1, :] = np.logaddexp(r[t - 1, 1, :], r[t - 1, 0, :]) + self.x[t, self.blank]
            psi = np.logaddexp(psi, phi[t - 1] + self.x[t, candidates])

        # Appending <eos> terminates the sequence: its score is the total
        # probability of g itself at the final frame.
        if self.eos in candidates:
            psi[candidates.index(self.eos)] = sum_prev[-1]
        return psi, np.rollaxis(r, 2)

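# A minimal usage sketch of CTCPrefixScore (illustrative, not part of the
# original pipeline; `log_post` is an assumed (1, T, V) tensor of CTC
# log-posteriors):
#
#   scorer = CTCPrefixScore(log_post)
#   r = scorer.init_state()                           # state of the empty prefix
#   psi, r_all = scorer.cheap_compute([], r, [2, 3, 4])
#   # psi[i]  : log-score of appending candidates[i] to the prefix
#   # r_all[i]: state to pass as r_prev if that extension is kept

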
class CTCHypothesis:
    '''
    Hypothesis for pure CTC beam search decoding.
    An implementation of Algo. 1 in http://proceedings.mlr.press/v32/graves14.pdf
    '''

    def __init__(self):
        self.y = []    # decoded token sequence so far

        # Log-probabilities of y ending at the current frame with a blank /
        # non-blank; the empty hypothesis "ends with a blank" with probability 1.
        self.Pr_y_t_blank = 0.0
        self.Pr_y_t_nblank = LOG_ZERO

        # Snapshots of the two probabilities from the previous frame.
        self.Pr_y_t_blank_bkup = 0.0
        self.Pr_y_t_nblank_bkup = LOG_ZERO

        # Cached language model output and state for this hypothesis.
        self.lm_output = None
        self.lm_hidden = None
        self.updated_lm = False

    def update_lm(self, output, hidden):
        self.lm_output = output
        self.lm_hidden = hidden
        self.updated_lm = True

    def get_len(self):
        return len(self.y)

    def get_string(self):
        # Stringified token sequence, used as a sort key for merging duplicates.
        return ''.join(str(s) for s in self.y)

    def get_score(self):
        return np.logaddexp(self.Pr_y_t_blank, self.Pr_y_t_nblank)

    def get_final_score(self):
        # Length-normalized score used for the final ranking.
        if len(self.y) > 0:
            return np.logaddexp(self.Pr_y_t_blank, self.Pr_y_t_nblank) / len(self.y)
        else:
            return np.logaddexp(self.Pr_y_t_blank, self.Pr_y_t_nblank)

    def check_same(self, y_2):
        return self.y == y_2

    def update_Pr_nblank(self, ctc_y_t):
        # Case: the last token of y repeats at frame t and is collapsed by CTC,
        # so y is unchanged and still ends with a non-blank.
        self.Pr_y_t_nblank += ctc_y_t

    def update_Pr_nblank_prefix(self, ctc_y_t, Pr_y_t_blank_prefix, Pr_y_t_nblank_prefix, Pr_ye_y=None):
        # Case: y is produced by extending its own prefix y[:-1] (tracked by
        # another beam entry) with its last token at frame t.
        lm_prob = Pr_ye_y if Pr_ye_y is not None else 0.0
        if len(self.y) == 0:
            return
        if len(self.y) == 1:
            Pr_ye_y_prefix = ctc_y_t + lm_prob + \
                np.logaddexp(Pr_y_t_blank_prefix, Pr_y_t_nblank_prefix)
        else:
            # If the last two tokens are identical, a blank must separate them,
            # so only the blank-ending path of the prefix contributes.
            Pr_ye_y_prefix = ctc_y_t + lm_prob + \
                (Pr_y_t_blank_prefix if self.y[-1] == self.y[-2]
                 else np.logaddexp(Pr_y_t_blank_prefix, Pr_y_t_nblank_prefix))
        self.Pr_y_t_nblank = np.logaddexp(self.Pr_y_t_nblank, Pr_ye_y_prefix)

    def update_Pr_blank(self, ctc_blank_t):
        # Case: a blank is emitted at frame t, reachable from either ending
        # state of the previous frame.
        self.Pr_y_t_blank = np.logaddexp(self.Pr_y_t_nblank_bkup, self.Pr_y_t_blank_bkup) + ctc_blank_t

    def add_token(self, token, ctc_token_t, Pr_k_y=None):
        # Extend y with a new token at frame t. If the token repeats the last
        # one, a blank must separate them, so only the blank-ending path of the
        # previous frame contributes.
        lm_prob = Pr_k_y if Pr_k_y is not None else 0.0
        if len(self.y) == 0:
            Pr_y_t_nblank_new = ctc_token_t + lm_prob + \
                np.logaddexp(self.Pr_y_t_blank_bkup, self.Pr_y_t_nblank_bkup)
        else:
            Pr_y_t_nblank_new = ctc_token_t + lm_prob + \
                (self.Pr_y_t_blank_bkup if self.y[-1] == token
                 else np.logaddexp(self.Pr_y_t_blank_bkup, self.Pr_y_t_nblank_bkup))

        # The extended hypothesis cannot end with a blank at this frame.
        self.Pr_y_t_blank = LOG_ZERO
        self.Pr_y_t_nblank = Pr_y_t_nblank_new

        self.Pr_y_t_blank_bkup = self.Pr_y_t_blank
        self.Pr_y_t_nblank_bkup = self.Pr_y_t_nblank

        self.y.append(token)

    def orig_backup(self):
        # Snapshot current-frame probabilities before moving to the next frame.
        self.Pr_y_t_blank_bkup = self.Pr_y_t_blank
        self.Pr_y_t_nblank_bkup = self.Pr_y_t_nblank

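# A minimal sketch of the per-frame bookkeeping above (illustrative; the
# numeric log-probabilities are assumed values, not from the original code):
#
#   hyp = CTCHypothesis()
#   hyp.add_token(5, ctc_token_t=-0.3)       # extend [] -> [5]
#   hyp.orig_backup()                        # snapshot before the next frame
#   hyp.update_Pr_nblank(ctc_y_t=-0.2)       # path: token 5 repeats (CTC-collapsed)
#   hyp.update_Pr_blank(ctc_blank_t=-0.1)    # path: a blank follows
#   score = hyp.get_score()                  # logaddexp of both paths

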
class CTCBeamDecoder(nn.Module):
    ''' Beam decoder for ASR (CTC only) '''

    def __init__(self, asr, vocab_range, beam_size, vocab_candidate,
                 lm_path='', lm_config='', lm_weight=0.0, device=None):
        super().__init__()
        self.asr = asr
        self.vocab_range = vocab_range        # token ids considered during decoding
        self.beam_size = beam_size
        self.vocab_cand = vocab_candidate     # number of tokens expanded per hypothesis per frame
        assert self.vocab_cand <= len(self.vocab_range)
        assert self.asr.enable_ctc

        # Optional shallow fusion with a pretrained RNN language model.
        self.apply_lm = lm_weight > 0
        self.lm_w = 0
        if self.apply_lm:
            self.device = device
            self.lm_w = lm_weight
            self.lm_path = lm_path
            with open(lm_config, 'r') as f:
                lm_config = yaml.load(f, Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size, **lm_config['model']).to(self.device)
            self.lm.load_state_dict(torch.load(self.lm_path, map_location='cpu')['model'])
            self.lm.eval()

    def create_msg(self):
        msg = ['Decode spec| CTC decoding \t| Beam size = {} \t| LM weight = {}'.format(
            self.beam_size, self.lm_w)]
        return msg

    def forward(self, feat, feat_len):
        assert feat.shape[0] == 1, "Batch size == 1 is required for beam search"

        # Run the ASR model; only the CTC posteriors are used here, the
        # attention decoder outputs are discarded.
        ctc_output, encode_len, att_output, att_align, dec_state = \
            self.asr(feat, feat_len, 10)
        del encode_len, att_output, att_align, dec_state, feat_len
        ctc_output = F.log_softmax(ctc_output[0], dim=-1).cpu().detach().numpy()
        T = len(ctc_output)

        # The beam starts with a single empty hypothesis.
        B = [CTCHypothesis()]
        if self.apply_lm:
            # Prime the LM with the initial (index-0) token.
            output, hidden = \
                self.lm(torch.zeros((1, 1), dtype=torch.long).to(self.device),
                        torch.ones(1, dtype=torch.long).to(self.device), None)
            B[0].update_lm(
                output.log_softmax(dim=-1).squeeze().cpu().numpy(),
                hidden
            )

        start = True
        for t in range(T):
            # Skip leading frames dominated by the blank.
            if np.argmax(ctc_output[t]) == 0 and start:
                continue
            else:
                start = False
            B_new = []
            for i in range(len(B)):
                B_i_new = copy.deepcopy(B[i])
                if B_i_new.get_len() > 0:
                    # Hypotheses that already ended with <eos> (token 1) are frozen.
                    if B_i_new.y[-1] == 1:
                        B_new.append(B_i_new)
                        continue
                    # Path: the last token repeats at frame t (collapsed by CTC).
                    B_i_new.update_Pr_nblank(ctc_output[t, B_i_new.y[-1]])

                    # If another beam entry equals this hypothesis minus its
                    # last token, add the path that extends that prefix at t.
                    for j in range(len(B)):
                        if i != j and B[j].check_same(B_i_new.y[:-1]):
                            lm_prob = 0.0
                            if self.apply_lm:
                                lm_prob = self.lm_w * B[j].lm_output[B_i_new.y[-1]]
                            B_i_new.update_Pr_nblank_prefix(ctc_output[t, B_i_new.y[-1]],
                                                            B[j].Pr_y_t_blank,
                                                            B[j].Pr_y_t_nblank, lm_prob)
                            break
                B_i_new.update_Pr_blank(ctc_output[t, 0])
                lm_probs = B_i_new.lm_output if self.apply_lm else None

                # Rank candidate tokens by CTC score, optionally fused with the LM.
                if self.apply_lm:
                    ctc_vocab_cand = sorted(zip(
                        self.vocab_range,
                        ctc_output[t, self.vocab_range] + self.lm_w * lm_probs[self.vocab_range]),
                        reverse=True, key=lambda x: x[1])
                else:
                    ctc_vocab_cand = sorted(zip(self.vocab_range, ctc_output[t, self.vocab_range]),
                                            reverse=True, key=lambda x: x[1])

                # Expand the hypothesis with the top vocab_cand tokens.
                for j in range(self.vocab_cand):
                    k = ctc_vocab_cand[j][0]
                    hyp_yk = copy.deepcopy(B_i_new)
                    lm_prob = 0.0 if not self.apply_lm else self.lm_w * lm_probs[k]
                    hyp_yk.add_token(k, ctc_output[t, k], lm_prob)
                    hyp_yk.updated_lm = False
                    B_new.append(hyp_yk)
                B_i_new.orig_backup()    # the unextended hypothesis also stays
                B_new.append(B_i_new)
            del B
            B = []

            # Merge duplicate hypotheses, keeping the higher-scoring copy.
            B_new = sorted(B_new, key=lambda x: x.get_string())
            B.append(B_new[0])
            for i in range(1, len(B_new)):
                if B_new[i].check_same(B[-1].y):
                    if B_new[i].get_score() > B[-1].get_score():
                        B[-1] = B_new[i]
                    continue
                else:
                    B.append(B_new[i])
            del B_new

            # Prune to the beam size; the last frame uses the length-normalized score.
            if t == T - 1:
                B = sorted(B, reverse=True, key=lambda x: x.get_final_score())
            else:
                B = sorted(B, reverse=True, key=lambda x: x.get_score())
            if len(B) > self.beam_size:
                B = B[:self.beam_size]

            # Advance the LM state of hypotheses that gained a token this frame.
            if self.apply_lm and t < T - 1:
                for i in range(len(B)):
                    if B[i].get_len() > 0 and not B[i].updated_lm:
                        output, hidden = \
                            self.lm(B[i].y[-1] * torch.ones((1, 1), dtype=torch.long).to(self.device),
                                    torch.ones(1, dtype=torch.long).to(self.device), B[i].lm_hidden)
                        B[i].update_lm(
                            output.log_softmax(dim=-1).squeeze().cpu().numpy(),
                            hidden
                        )

        return [b.y for b in B]
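# A minimal usage sketch of the decoder (illustrative; `asr_model`,
# `VOCAB_SIZE`, `feat`, and `feat_len` are assumptions, not defined here):
#
#   decoder = CTCBeamDecoder(asr_model,
#                            vocab_range=list(range(1, VOCAB_SIZE)),
#                            beam_size=10, vocab_candidate=5,
#                            lm_path='lm.pth', lm_config='lm.yaml',
#                            lm_weight=0.3, device='cuda')
#   with torch.no_grad():
#       hyps = decoder(feat, feat_len)    # feat: (1, T, D), feat_len: (1,)
#   best_tokens = hyps[0]                 # token ids of the top-scoring beam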