Spaces:
Sleeping
Sleeping
File size: 5,711 Bytes
6fc683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wav2letter decoders.
"""
import math
import itertools as it
import torch
from fairseq import utils
from examples.speech_recognition.data.replabels import unpack_replabels
from wav2letter.common import create_word_dict, load_words
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
from wav2letter.decoder import (
CriterionType,
DecoderOptions,
KenLM,
SmearingMode,
Trie,
WordLMDecoder,
)
class W2lDecoder:
    """Base class for the wav2letter-backed decoders in this module.

    Stores the target dictionary and the criterion-specific settings (CTC
    blank index or ASG transitions), runs the encoder to obtain emissions,
    and post-processes decoded index sequences into token tensors.
    Subclasses provide ``decode(emissions)``.
    """

    def __init__(self, args, tgt_dict):
        """Store the dictionary and set up criterion-specific state.

        Args:
            args: options object. Reads ``nbest`` and ``criterion``; when the
                criterion is ``"asg_loss"`` it also reads ``asg_transitions``
                and ``max_replabel``.
            tgt_dict: target dictionary; must support ``len()`` and
                ``index(symbol)``.

        Raises:
            RuntimeError: if ``args.criterion`` is neither ``"ctc_loss"`` nor
                ``"asg_loss"``.
        """
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = args.nbest

        # criterion-specific init
        if args.criterion == "ctc_loss":
            self.criterion_type = CriterionType.CTC
            self.blank = tgt_dict.index("<ctc_blank>")
            self.asg_transitions = None
        elif args.criterion == "asg_loss":
            self.criterion_type = CriterionType.ASG
            self.blank = -1
            self.asg_transitions = args.asg_transitions
            self.max_replabel = args.max_replabel
            # The transitions are stored flat; subclasses view them as a
            # vocab_size x vocab_size matrix.
            assert len(self.asg_transitions) == self.vocab_size ** 2, (
                "asg_transitions must contain vocab_size ** 2 entries"
            )
        else:
            raise RuntimeError(f"unknown criterion: {args.criterion}")

    def generate(self, models, sample, prefix_tokens=None):
        """Generate a batch of inferences.

        Args:
            models: sequence of models; only ``models[0]`` is used (see
                ``get_emissions``).
            sample: batch dict whose ``"net_input"`` entry holds the keyword
                arguments for the encoder.
            prefix_tokens: accepted for interface compatibility; not used.

        Returns:
            Whatever ``self.decode`` returns for the computed emissions.
        """
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(self, models, encoder_input):
        """Run the encoder of ``models[0]`` and return emissions on the CPU.

        For CTC the emissions are the model's normalized log-probabilities;
        for ASG they are the raw ``"encoder_out"`` entry of the encoder
        output.

        Returns:
            A contiguous float32 CPU tensor with the first two dimensions of
            the model output swapped (presumably time-major to batch-major;
            confirm against the encoder).

        Raises:
            RuntimeError: if ``self.criterion_type`` is neither CTC nor ASG.
        """
        model = models[0]
        encoder_out = model.encoder(**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)
        elif self.criterion_type == CriterionType.ASG:
            emissions = encoder_out["encoder_out"]
        else:
            # __init__ only ever sets CTC or ASG; fail loudly rather than with
            # an UnboundLocalError if the attribute was changed afterwards.
            raise RuntimeError(
                f"unsupported criterion type: {self.criterion_type}"
            )
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc.

        Args:
            idxs: iterable of integer token indices.

        Returns:
            ``torch.LongTensor`` of the cleaned-up indices.
        """
        # Collapse runs of identical consecutive indices into one.
        idxs = (g[0] for g in it.groupby(idxs))
        # Drop negative indices.
        idxs = filter(lambda x: x >= 0, idxs)
        if self.criterion_type == CriterionType.CTC:
            idxs = filter(lambda x: x != self.blank, idxs)
        elif self.criterion_type == CriterionType.ASG:
            idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel)
        return torch.LongTensor(list(idxs))

    def decode(self, emissions):
        """Decode a batch of emissions; subclasses must override this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
class W2lViterbiDecoder(W2lDecoder):
    """Decoder that returns the path computed by ``CpuViterbiPath``.

    No lexicon or language model is involved. Construction is inherited
    unchanged from ``W2lDecoder`` (``W2lViterbiDecoder(args, tgt_dict)``).
    """

    def decode(self, emissions):
        """Compute one Viterbi path per batch item.

        Args:
            emissions: 3-D tensor unpacked as ``(B, T, N)``. Its raw data
                pointer is handed to the native kernel, so it is presumably
                expected to be a contiguous float32 CPU tensor as produced by
                ``get_emissions`` -- confirm before passing anything else.

        Returns:
            A list with ``B`` entries; each entry is a one-element list
            containing ``{"tokens": LongTensor, "score": 0}``.
        """
        B, T, N = emissions.size()

        # CTC has no transition scores, so an all-zero N x N matrix is used;
        # ASG supplies its flat transition list, viewed as N x N.
        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)

        # Output buffer and scratch space written by the native kernel.
        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
            for b in range(B)
        ]
class W2lKenLMDecoder(W2lDecoder):
    """Lexicon-constrained beam-search decoder scored with a KenLM model."""

    def __init__(self, args, tgt_dict):
        """Load the lexicon and language model, build the trie and decoder.

        Args:
            args: options object. In addition to what ``W2lDecoder`` reads,
                this uses ``silence_token``, ``lexicon``, ``kenlm_model``,
                ``beam``, ``beam_threshold``, ``lm_weight``, ``word_score``,
                ``unk_weight`` and ``sil_weight``.
            tgt_dict: target dictionary; must support ``len()`` and
                ``index(symbol)``.
        """
        super().__init__(args, tgt_dict)

        self.silence = tgt_dict.index(args.silence_token)
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        # Insert every spelling of every lexicon word into the trie, scored
        # with the LM score of that word from the LM start state.
        start_state = self.lm.start(False)
        for word, spellings in self.lexicon.items():
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        # Arguments are positional, in the order the wav2letter binding
        # defines. NOTE(review): the literal False is presumably the log-add
        # switch -- confirm against the DecoderOptions signature.
        self.decoder_opts = DecoderOptions(
            args.beam,
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            False,
            args.sil_weight,
            self.criterion_type,
        )

        # NOTE(review): for CTC, self.asg_transitions is None here; confirm
        # that the WordLMDecoder binding accepts None for its transitions.
        self.decoder = WordLMDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
        )

    def decode(self, emissions):
        """Beam-search decode each batch item and keep the top ``nbest``.

        Args:
            emissions: 3-D tensor unpacked as ``(B, T, N)``. A raw address
                into its storage is passed to the native decoder, so it is
                presumably expected to be a contiguous float32 CPU tensor as
                produced by ``get_emissions`` -- confirm before passing
                anything else.

        Returns:
            A list with ``B`` entries; each entry is a list of at most
            ``self.nbest`` dicts ``{"tokens": LongTensor, "score": score}``.
        """
        B, T, N = emissions.size()

        # Byte distance between consecutive batch items: stride(0) counts
        # elements, so scale by the element size (4 for float32).
        base_ptr = emissions.data_ptr()
        batch_stride_bytes = emissions.element_size() * emissions.stride(0)

        hypos = []
        for b in range(B):
            emissions_ptr = base_ptr + b * batch_stride_bytes
            nbest_results = self.decoder.decode(emissions_ptr, T, N)[: self.nbest]
            hypos.append(
                [
                    {"tokens": self.get_tokens(result.tokens), "score": result.score}
                    for result in nbest_results
                ]
            )
        return hypos
|