Spaces:
Running
Running
File size: 13,178 Bytes
864affd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 |
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torchaudio.models import RNNT
# Names exported by ``from <module> import *``.
__all__ = [
    "Hypothesis",
    "RNNTBeamSearch",
]

# A hypothesis is a plain 4-tuple:
#   (token ids, prediction network output, prediction network state, score).
# The ``_get_hypo_*`` helpers in this module read its fields by position.
Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
represented as tuple of (tokens, prediction network output, prediction network state, score).
"""
def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
    """Return the token-id sequence stored in position 0 of ``hypo``."""
    tokens = hypo[0]
    return tokens
def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
    """Return the prediction network output stored in position 1 of ``hypo``."""
    predictor_out = hypo[1]
    return predictor_out
def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
    """Return the prediction network state stored in position 2 of ``hypo``."""
    predictor_state = hypo[2]
    return predictor_state
def _get_hypo_score(hypo: Hypothesis) -> float:
    """Return the accumulated score stored in position 3 of ``hypo``."""
    score = hypo[3]
    return score
def _get_hypo_key(hypo: Hypothesis) -> str:
    """Return a string key identifying ``hypo`` by its token sequence.

    Two hypotheses share a key exactly when their token lists are equal, so the
    key is used to detect and merge duplicate token sequences.
    """
    tokens = hypo[0]
    return str(tokens)
def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
    """Concatenate the prediction network states of ``hypos`` along dim 0.

    The nesting (outer list length and inner list lengths) is taken from
    ``hypos[0]``; every other hypothesis is indexed with the same layout, so all
    states are expected to share that structure. Entry ``[i][j]`` of the result
    is ``torch.cat`` of entry ``[i][j]`` from every hypothesis, in ``hypos`` order.
    """
    template = _get_hypo_state(hypos[0])
    batched: List[List[torch.Tensor]] = []
    for layer_idx in range(len(template)):
        batched.append(
            [
                torch.cat([_get_hypo_state(h)[layer_idx][comp_idx] for h in hypos])
                for comp_idx in range(len(template[layer_idx]))
            ]
        )
    return batched
def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
idx_tensor = torch.tensor([idx], device=device)
return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
def _default_hypo_sort_key(hypo: Hypothesis) -> float:
    """Default ranking key: hypothesis score divided by (token count + 1)."""
    denominator = len(_get_hypo_tokens(hypo)) + 1
    return _get_hypo_score(hypo) / denominator
def _compute_updated_scores(
    hypos: List[Hypothesis],
    next_token_probs: torch.Tensor,
    beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pick the best-scoring non-blank (hypothesis, token) extensions.

    Each hypothesis score is added to the log-probabilities in its row of
    ``next_token_probs``. The last column is treated as blank and excluded.
    The highest totals are then selected over all rows together.

    Args:
        hypos: hypotheses being extended; row ``i`` of ``next_token_probs``
            belongs to ``hypos[i]``.
        next_token_probs: 2-D tensor ``(len(hypos), num_tokens)`` of token scores.
        beam_width: maximum number of candidates to return.

    Returns:
        ``(scores, hypo_idx, token)``: 1-D tensors of equal length, sorted by
        descending score. ``hypo_idx`` indexes ``hypos``; ``token`` is a column of
        ``next_token_probs[:, :-1]``. The length is ``min(beam_width,
        len(hypos) * (num_tokens - 1))``.
    """
    hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
    nonblank_scores = hypo_scores + next_token_probs[:, :-1]  # [num_hypos, num_tokens - 1]
    # topk raises if k exceeds the number of elements (small vocabulary or few
    # hypotheses), so never ask for more candidates than exist.
    k = min(beam_width, nonblank_scores.numel())
    nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(k)
    # Undo the flattening: row -> hypothesis index, column -> token index.
    nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
    nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
    return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
    """Delete, in place, the first entry of ``hypo_list`` with the same key as ``hypo``.

    Keys come from ``_get_hypo_key``. If no entry matches, ``hypo_list`` is left
    unchanged.
    """
    target_key = _get_hypo_key(hypo)
    for position in range(len(hypo_list)):
        if _get_hypo_key(hypo_list[position]) == target_key:
            del hypo_list[position]
            break
class RNNTBeamSearch(torch.nn.Module):
    r"""Beam search decoder for RNN-T model.

    See Also:
        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.

    Args:
        model (RNNT): RNN-T model to use.
        blank (int): index of blank token in vocabulary. Blank scores are read
            from this column of the joint network output, which does not need
            to be the last one.
        temperature (float, optional): temperature to apply to joint network output.
            Larger values yield more uniform samples. (Default: 1.0)
        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
            hypothesis score normalized by token sequence length. (Default: None)
        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
    """

    def __init__(
        self,
        model: RNNT,
        blank: int,
        temperature: float = 1.0,
        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
        step_max_tokens: int = 100,
    ) -> None:
        super().__init__()
        self.model = model
        self.blank = blank
        self.temperature = temperature
        if hypo_sort_key is None:
            self.hypo_sort_key = _default_hypo_sort_key
        else:
            self.hypo_sort_key = hypo_sort_key
        self.step_max_tokens = step_max_tokens

    def _blank_index(self, num_tokens: int) -> int:
        """Return ``self.blank`` as a non-negative column index for ``num_tokens`` columns."""
        if self.blank < 0:
            return self.blank + num_tokens
        return self.blank

    def _move_blank_last(self, next_token_probs: torch.Tensor) -> torch.Tensor:
        """Return ``next_token_probs`` with the blank column moved to the end.

        ``_compute_updated_scores`` excludes the last column as blank, so this
        lets any ``blank`` index be used. Non-blank columns keep their order:
        column ``k`` of the result is token ``k`` when ``k < blank`` and token
        ``k + 1`` otherwise. The tensor is returned unchanged when blank is
        already last.
        """
        num_tokens = next_token_probs.shape[1]
        blank = self._blank_index(num_tokens)
        if blank == num_tokens - 1:
            return next_token_probs
        return torch.cat(
            [
                next_token_probs[:, :blank],
                next_token_probs[:, blank + 1 :],
                next_token_probs[:, blank : blank + 1],
            ],
            dim=1,
        )

    def _prepare_input(self, input_frames: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Validate ``input``/``length`` shapes and add the leading batch dimension.

        Raises:
            ValueError: if ``input_frames`` is not (T, D) or (1, T, D), or if
                ``length`` is not of shape () or (1,).
        """
        if input_frames.dim() != 2 and not (input_frames.dim() == 3 and input_frames.shape[0] == 1):
            raise ValueError("input must be of shape (T, D) or (1, T, D)")
        if input_frames.dim() == 2:
            input_frames = input_frames.unsqueeze(0)
        if length.shape != () and length.shape != (1,):
            raise ValueError("length must be of shape () or (1,)")
        if length.dim() == 0:
            length = length.unsqueeze(0)
        return input_frames, length

    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
        """Build the seed hypothesis.

        The blank token is run through the prediction network with no prior
        state, and the result is given a score of 0.0.
        """
        token = self.blank
        state = None
        one_tensor = torch.tensor([1], device=device)
        pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
        init_hypo = (
            [token],
            pred_out[0].detach(),
            pred_state,
            0.0,
        )
        return [init_hypo]

    def _gen_next_token_probs(
        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
    ) -> torch.Tensor:
        """Join one encoder frame with each hypothesis' predictor output.

        Returns:
            Temperature-scaled log-softmax over the token dimension, with one
            row per entry of ``hypos``.
        """
        one_tensor = torch.tensor([1], device=device)
        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
        joined_out, _, _ = self.model.join(
            enc_out,
            one_tensor,
            predictor_out,
            torch.tensor([1] * len(hypos), device=device),
        )  # [beam_width, 1, 1, num_tokens]
        joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
        return joined_out[:, 0, 0]

    def _gen_b_hypos(
        self,
        b_hypos: List[Hypothesis],
        a_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        key_to_b_hypo: Dict[str, Hypothesis],
    ) -> List[Hypothesis]:
        """Extend each hypothesis in ``a_hypos`` with blank and merge into ``b_hypos``.

        A blank extension keeps the source hypothesis' tokens, predictor output
        and state, and only updates the score. If ``key_to_b_hypo`` already has
        an entry with the same token sequence, that entry is removed from
        ``b_hypos`` and its score is combined with the new one via ``logaddexp``.
        ``b_hypos`` and ``key_to_b_hypo`` are updated in place.

        Returns:
            The entries of ``b_hypos`` sorted by ascending score.
        """
        for i in range(len(a_hypos)):
            h_a = a_hypos[i]
            # Read the blank score from the configured blank column rather than
            # assuming blank is the last token.
            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, self.blank]
            if _get_hypo_key(h_a) in key_to_b_hypo:
                h_b = key_to_b_hypo[_get_hypo_key(h_a)]
                _remove_hypo(h_b, b_hypos)
                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
            else:
                score = float(append_blank_score)
            h_b = (
                _get_hypo_tokens(h_a),
                _get_hypo_predictor_out(h_a),
                _get_hypo_state(h_a),
                score,
            )
            b_hypos.append(h_b)
            key_to_b_hypo[_get_hypo_key(h_b)] = h_b
        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
        return [b_hypos[idx] for idx in sorted_idx]

    def _gen_a_hypos(
        self,
        a_hypos: List[Hypothesis],
        b_hypos: List[Hypothesis],
        next_token_probs: torch.Tensor,
        t: int,
        beam_width: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Extend hypotheses with non-blank tokens, keeping only competitive ones.

        A (hypothesis, token) candidate is kept only if its score beats the
        ``beam_width``-th best score in ``b_hypos``. When ``b_hypos`` has fewer
        than ``beam_width`` entries, that threshold is ``-inf``.
        ``b_hypos`` is assumed to be sorted by ascending score, as
        ``_gen_b_hypos`` returns it.
        """
        blank = self._blank_index(next_token_probs.shape[1])
        (
            nonblank_nbest_scores,
            nonblank_nbest_hypo_idx,
            nonblank_nbest_token,
        ) = _compute_updated_scores(a_hypos, self._move_blank_last(next_token_probs), beam_width)
        if len(b_hypos) < beam_width:
            b_nbest_score = -float("inf")
        else:
            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])

        base_hypos: List[Hypothesis] = []
        new_tokens: List[int] = []
        new_scores: List[float] = []
        # Iterate over the candidates actually returned; there can be fewer than
        # beam_width when the search space is small.
        for i in range(nonblank_nbest_scores.size(0)):
            score = float(nonblank_nbest_scores[i])
            if score > b_nbest_score:
                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
                token = int(nonblank_nbest_token[i])
                # Map the column of the blank-last tensor back to the vocabulary id.
                if token >= blank:
                    token += 1
                base_hypos.append(a_hypos[a_hypo_idx])
                new_tokens.append(token)
                new_scores.append(score)

        new_hypos: List[Hypothesis] = []
        if base_hypos:
            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
        return new_hypos

    def _gen_new_hypos(
        self,
        base_hypos: List[Hypothesis],
        tokens: List[int],
        scores: List[float],
        t: int,
        device: torch.device,
    ) -> List[Hypothesis]:
        """Run the prediction network once, batched, to append ``tokens[i]`` to ``base_hypos[i]``.

        Each new hypothesis gets its own slice of the batched predictor output
        and state, and score ``scores[i]``. ``t`` is not used here.
        """
        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
        states = _batch_state(base_hypos)
        pred_out, _, pred_states = self.model.predict(
            tgt_tokens,
            torch.tensor([1] * len(base_hypos), device=device),
            states,
        )
        new_hypos: List[Hypothesis] = []
        for i, h_a in enumerate(base_hypos):
            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
            new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
        return new_hypos

    def _search(
        self,
        enc_out: torch.Tensor,
        hypo: Optional[List[Hypothesis]],
        beam_width: int,
    ) -> List[Hypothesis]:
        """Run beam search over every time step of ``enc_out``.

        ``enc_out.shape[1]`` is taken as the number of time steps. The search is
        seeded with ``hypo`` if given, otherwise with the blank-only hypothesis.
        At most ``step_max_tokens`` non-blank expansion rounds run per step.

        Returns:
            At most ``beam_width`` hypotheses, ranked by ``hypo_sort_key``
            (best first) after the last step. If there are no time steps, the
            seed hypotheses are returned unchanged.
        """
        n_time_steps = enc_out.shape[1]
        device = enc_out.device

        a_hypos: List[Hypothesis] = []
        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
        for t in range(n_time_steps):
            a_hypos = b_hypos
            b_hypos = torch.jit.annotate(List[Hypothesis], [])
            key_to_b_hypo: Dict[str, Hypothesis] = {}
            symbols_current_t = 0

            while a_hypos:
                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
                next_token_probs = next_token_probs.cpu()
                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)

                if symbols_current_t == self.step_max_tokens:
                    break

                a_hypos = self._gen_a_hypos(
                    a_hypos,
                    b_hypos,
                    next_token_probs,
                    t,
                    beam_width,
                    device,
                )
                if a_hypos:
                    symbols_current_t += 1

            # topk raises when k exceeds the element count; keep everything when
            # fewer than beam_width hypotheses survived this step.
            num_keep = min(beam_width, len(b_hypos))
            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(num_keep)
            b_hypos = [b_hypos[idx] for idx in sorted_idx]

        return b_hypos

    def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
        r"""Performs beam search for the given input sequence.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.

        Returns:
            List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
        """
        input, length = self._prepare_input(input, length)
        enc_out, _ = self.model.transcribe(input, length)
        return self._search(enc_out, None, beam_width)

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        length: torch.Tensor,
        beam_width: int,
        state: Optional[List[List[torch.Tensor]]] = None,
        hypothesis: Optional[List[Hypothesis]] = None,
    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
        r"""Performs beam search for the given input sequence in streaming mode.

        T: number of frames;
        D: feature dimension of each frame.

        Args:
            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
            length (torch.Tensor): number of valid frames in input
                sequence, with shape () or (1,).
            beam_width (int): beam size to use during search.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing transcription network internal state generated in preceding
                invocation. (Default: ``None``)
            hypothesis (List[Hypothesis] or None): hypotheses from preceding invocation to seed
                search with. (Default: ``None``)

        Returns:
            (List[Hypothesis], List[List[torch.Tensor]]):
                List[Hypothesis]
                    top-``beam_width`` hypotheses found by beam search.
                List[List[torch.Tensor]]
                    list of lists of tensors representing transcription network
                    internal state generated in current invocation.
        """
        input, length = self._prepare_input(input, length)
        enc_out, _, state = self.model.transcribe_streaming(input, length, state)
        return self._search(enc_out, hypothesis, beam_width), state
|