from typing import Callable, Dict, List, Optional, Tuple

import torch

from torchaudio.models import RNNT


__all__ = ["Hypothesis", "RNNTBeamSearch"]

Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
represented as a tuple of (tokens, prediction network output, prediction network state, score).
"""

def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
    return hypo[0]


def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
    return hypo[1]


def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
    return hypo[2]


def _get_hypo_score(hypo: Hypothesis) -> float:
    return hypo[3]


def _get_hypo_key(hypo: Hypothesis) -> str:
    return str(hypo[0])

def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
    # Concatenate the per-hypothesis prediction network states along the batch
    # dimension so a single batched ``predict`` call can serve all hypotheses.
    states: List[List[torch.Tensor]] = []
    for i in range(len(_get_hypo_state(hypos[0]))):
        batched_state_components: List[torch.Tensor] = []
        for j in range(len(_get_hypo_state(hypos[0])[i])):
            batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
        states.append(batched_state_components)
    return states

def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
    # Inverse of ``_batch_state``: extract the state slice for the hypothesis at ``idx``.
    idx_tensor = torch.tensor([idx], device=device)
    return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]

def _default_hypo_sort_key(hypo: Hypothesis) -> float:
    # Default ranking: hypothesis score normalized by token sequence length.
    return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)

def _compute_updated_scores(
    hypos: List[Hypothesis],
    next_token_probs: torch.Tensor,
    beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Add each hypothesis's cumulative score to the log-probabilities of all
    # non-blank tokens (blank occupies the last column and is excluded).
    hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
    nonblank_scores = hypo_scores + next_token_probs[:, :-1]  # [beam_width, num_tokens - 1]
    # Take the top ``beam_width`` entries over the flattened (hypothesis, token) grid,
    # then recover each entry's source hypothesis index and token index.
    nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
    nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
    nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
    return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token

def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
    # Remove the entry of ``hypo_list`` whose token sequence matches ``hypo``.
    for i, elem in enumerate(hypo_list):
        if _get_hypo_key(hypo) == _get_hypo_key(elem):
            del hypo_list[i]
            break

class RNNTBeamSearch(torch.nn.Module):
    r"""Beam search decoder for RNN-T model.

    See Also:
        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.

    Args:
        model (RNNT): RNN-T model to use.
        blank (int): index of blank token in vocabulary.
        temperature (float, optional): temperature to apply to joint network output.
            Larger values yield more uniform samples. (Default: 1.0)
        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
            hypothesis score normalized by token sequence length. (Default: ``None``)
        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
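
    Example
        An illustrative sketch (not part of the original docstring); the factory
        function and symbol counts below are assumptions chosen for demonstration.

        >>> import torch
        >>> from torchaudio.models import emformer_rnnt_base
        >>> model = emformer_rnnt_base(num_symbols=4097)  # blank assumed to be the last symbol
        >>> decoder = RNNTBeamSearch(model, blank=4096)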
""" | |

    def __init__(
        self,
        model: RNNT,
        blank: int,
        temperature: float = 1.0,
        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
        step_max_tokens: int = 100,
    ) -> None:
        super().__init__()
        self.model = model
        self.blank = blank
        self.temperature = temperature

        if hypo_sort_key is None:
            self.hypo_sort_key = _default_hypo_sort_key
        else:
            self.hypo_sort_key = hypo_sort_key

        self.step_max_tokens = step_max_tokens

    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
        # Seed the search with a single hypothesis containing only the blank token
        # and a score of 0.0 (i.e. log-probability of 1).
        token = self.blank
        state = None

        one_tensor = torch.tensor([1], device=device)
        pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
        init_hypo = (
            [token],
            pred_out[0].detach(),
            pred_state,
            0.0,
        )
        return [init_hypo]

    def _gen_next_token_probs(
        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
    ) -> torch.Tensor:
        # Join the current encoder frame with each hypothesis's predictor output and
        # return per-hypothesis log-probabilities over the vocabulary.
        one_tensor = torch.tensor([1], device=device)
        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
        joined_out, _, _ = self.model.join(
            enc_out,
            one_tensor,
            predictor_out,
            torch.tensor([1] * len(hypos), device=device),
        )  # [beam_width, 1, 1, num_tokens]
        joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
        return joined_out[:, 0, 0]

    def _gen_b_hypos(
        self,
        b_hypos: List[Hypothesis],
        a_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        key_to_b_hypo: Dict[str, Hypothesis],
    ) -> List[Hypothesis]:
        # Extend each active hypothesis with blank, terminating it for the current time
        # step. If a hypothesis with the same token sequence is already in the
        # blank-terminated set, merge the two by log-adding their scores.
        for i in range(len(a_hypos)):
            h_a = a_hypos[i]
            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
            if _get_hypo_key(h_a) in key_to_b_hypo:
                h_b = key_to_b_hypo[_get_hypo_key(h_a)]
                _remove_hypo(h_b, b_hypos)
                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
            else:
                score = float(append_blank_score)

            h_b = (
                _get_hypo_tokens(h_a),
                _get_hypo_predictor_out(h_a),
                _get_hypo_state(h_a),
                score,
            )
            b_hypos.append(h_b)
            key_to_b_hypo[_get_hypo_key(h_b)] = h_b

        # Keep the blank-terminated set sorted by ascending score so that the
        # ``beam_width``-th best element is easy to locate.
        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
        return [b_hypos[idx] for idx in sorted_idx]

    def _gen_a_hypos(
        self,
        a_hypos: List[Hypothesis],
        b_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        t: int,
        beam_width: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        (
            nonblank_nbest_scores,
            nonblank_nbest_hypo_idx,
            nonblank_nbest_token,
        ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)

        # A non-blank expansion stays active only if it outscores the ``beam_width``-th
        # best blank-terminated hypothesis; ``b_hypos`` is sorted ascending by score.
        if len(b_hypos) < beam_width:
            b_nbest_score = -float("inf")
        else:
            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])

        base_hypos: List[Hypothesis] = []
        new_tokens: List[int] = []
        new_scores: List[float] = []
        for i in range(beam_width):
            score = float(nonblank_nbest_scores[i])
            if score > b_nbest_score:
                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
                base_hypos.append(a_hypos[a_hypo_idx])
                new_tokens.append(int(nonblank_nbest_token[i]))
                new_scores.append(score)

        if base_hypos:
            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
        else:
            new_hypos: List[Hypothesis] = []

        return new_hypos

    def _gen_new_hypos(
        self,
        base_hypos: List[Hypothesis],
        tokens: List[int],
        scores: List[float],
        t: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        # Run the prediction network once over the batch of surviving hypotheses, then
        # split the batched outputs and states back into individual hypotheses.
        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
        states = _batch_state(base_hypos)
        pred_out, _, pred_states = self.model.predict(
            tgt_tokens,
            torch.tensor([1] * len(base_hypos), device=device),
            states,
        )
        new_hypos: List[Hypothesis] = []
        for i, h_a in enumerate(base_hypos):
            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
            new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
        return new_hypos
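
    # Commentary added for clarity (not in the original source): ``_search`` performs a
    # time-synchronous beam search in the spirit of Graves (2012), "Sequence Transduction
    # with Recurrent Neural Networks". Per time step, ``a_hypos`` holds hypotheses that
    # may still emit non-blank tokens for the current frame, while ``b_hypos`` holds
    # hypotheses that have emitted blank and thus advanced to the next frame.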

    def _search(
        self,
        enc_out: torch.Tensor,
        hypo: Optional[List[Hypothesis]],
        beam_width: int,
    ) -> List[Hypothesis]:
        n_time_steps = enc_out.shape[1]
        device = enc_out.device

        a_hypos: List[Hypothesis] = []
        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
        for t in range(n_time_steps):
            a_hypos = b_hypos
            b_hypos = torch.jit.annotate(List[Hypothesis], [])
            key_to_b_hypo: Dict[str, Hypothesis] = {}
            symbols_current_t = 0

            while a_hypos:
                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
                next_token_probs = next_token_probs.cpu()
                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)

                # Cap the number of non-blank emissions per time step to bound runtime.
                if symbols_current_t == self.step_max_tokens:
                    break

                a_hypos = self._gen_a_hypos(
                    a_hypos,
                    b_hypos,
                    next_token_probs,
                    t,
                    beam_width,
                    device,
                )
                if a_hypos:
                    symbols_current_t += 1

            # Retain the ``beam_width`` best blank-terminated hypotheses under the sort key.
            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
            b_hypos = [b_hypos[idx] for idx in sorted_idx]

        return b_hypos

    def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
        r"""Performs beam search for the given input sequence.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.

        Returns:
            List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
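
        Example
            An illustrative sketch (not part of the original docstring); ``decoder`` is an
            ``RNNTBeamSearch`` instance and ``features`` a (T, D) feature tensor, both assumed.

            >>> hypos = decoder(features, torch.tensor(features.shape[0]), beam_width=10)
            >>> best_tokens = hypos[0][0]  # token list of the best hypothesis, including the blank seed token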
""" | |
if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1): | |
raise ValueError("input must be of shape (T, D) or (1, T, D)") | |
if input.dim() == 2: | |
input = input.unsqueeze(0) | |
if length.shape != () and length.shape != (1,): | |
raise ValueError("length must be of shape () or (1,)") | |
if length.dim() == 0: | |
length = length.unsqueeze(0) | |
enc_out, _ = self.model.transcribe(input, length) | |
return self._search(enc_out, None, beam_width) | |

    def infer(
        self,
        input: torch.Tensor,
        length: torch.Tensor,
        beam_width: int,
        state: Optional[List[List[torch.Tensor]]] = None,
        hypothesis: Optional[List[Hypothesis]] = None,
    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
        r"""Performs beam search for the given input sequence in streaming mode.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing transcription network internal state generated in preceding
                invocation. (Default: ``None``)
            hypothesis (List[Hypothesis] or None, optional): hypotheses from preceding invocation
                to seed search with. (Default: ``None``)

        Returns:
            (List[Hypothesis], List[List[torch.Tensor]]):
                List[Hypothesis]
                    top-``beam_width`` hypotheses found by beam search.
                List[List[torch.Tensor]]
                    list of lists of tensors representing transcription network
                    internal state generated in current invocation.
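
        Example
            An illustrative streaming sketch (not part of the original docstring); ``decoder``
            and the chunk iterable ``segments`` are assumptions for demonstration. State and
            hypotheses are carried across invocations.

            >>> state, hypos = None, None
            >>> for segment in segments:  # each segment is a (T, D) feature tensor
            ...     hypos, state = decoder.infer(
            ...         segment, torch.tensor(segment.shape[0]), 10, state=state, hypothesis=hypos
            ...     )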
""" | |
if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1): | |
raise ValueError("input must be of shape (T, D) or (1, T, D)") | |
if input.dim() == 2: | |
input = input.unsqueeze(0) | |
if length.shape != () and length.shape != (1,): | |
raise ValueError("length must be of shape () or (1,)") | |
if length.dim() == 0: | |
length = length.unsqueeze(0) | |
enc_out, _, state = self.model.transcribe_streaming(input, length, state) | |
return self._search(enc_out, hypothesis, beam_width), state | |