from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from torchaudio._extension import fail_if_no_align

__all__ = []


@fail_if_no_align
def forced_align(
    log_probs: Tensor,
    targets: Tensor,
    input_lengths: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    blank: int = 0,
) -> Tuple[Tensor, Tensor]:
r"""Align a CTC label sequence to an emission. | |
.. devices:: CPU CUDA | |
.. properties:: TorchScript | |
Args: | |
log_probs (Tensor): log probability of CTC emission output. | |
Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length, | |
`C` is the number of characters in alphabet including blank. | |
targets (Tensor): Target sequence. Tensor of shape `(B, L)`, | |
where `L` is the target length. | |
input_lengths (Tensor or None, optional): | |
Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`. | |
target_lengths (Tensor or None, optional): | |
Lengths of the targets. 1-D Tensor of shape `(B,)`. | |
blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0) | |
Returns: | |
Tuple(Tensor, Tensor): | |
Tensor: Label for each time step in the alignment path computed using forced alignment. | |
Tensor: Log probability scores of the labels for each time step. | |
Note: | |
The sequence length of `log_probs` must satisfy: | |
.. math:: | |
L_{\text{log\_probs}} \ge L_{\text{label}} + N_{\text{repeat}} | |
where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens. | |
For example, in str `"aabbc"`, the number of repeats are `2`. | |
Note: | |
The current version only supports ``batch_size==1``. | |
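
    Example:
        A minimal sketch, assuming ``emission`` is a hypothetical `(1, T, C)` tensor of raw
        network outputs and ``targets`` is a `(1, L)` tensor of token indices::

            >>> log_probs = torch.log_softmax(emission, dim=-1)
            >>> paths, scores = forced_align(log_probs, targets, blank=0)
            >>> paths.shape  # frame-level token indices, `(1, T)`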
""" | |
    if blank in targets:
        raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
    if torch.max(targets) >= log_probs.shape[-1]:
        raise ValueError("targets values must be less than the CTC dimension")

    if input_lengths is None:
        batch_size, length = log_probs.size(0), log_probs.size(1)
        input_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=log_probs.device)
    if target_lengths is None:
        batch_size, length = targets.size(0), targets.size(1)
        target_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=targets.device)

    # For TorchScript compatibility
    assert input_lengths is not None
    assert target_lengths is not None

    paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    return paths, scores


@dataclass
class TokenSpan:
    """TokenSpan()
    Token with time stamps and score. Returned by :py:func:`merge_tokens`.
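
    Example:
        An illustrative span with made-up values, covering emission frames ``[5, 9)``::

            >>> span = TokenSpan(token=3, start=5, end=9, score=0.98)
            >>> len(span)  # number of emission frames covered by the token
            4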
""" | |
    token: int
    """The token"""

    start: int
    """The start time (inclusive) in emission time axis."""

    end: int
    """The end time (exclusive) in emission time axis."""

    score: float
    """The score of this token."""

    def __len__(self) -> int:
        """Returns the length of the time span (in emission frames)."""
        return self.end - self.start


def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSpan]:
    """Removes repeated tokens and blank tokens from the given CTC token sequence.

    Args:
        tokens (Tensor): Alignment tokens (unbatched) returned from :py:func:`forced_align`.
            Shape: `(time, )`.
        scores (Tensor): Alignment scores (unbatched) returned from :py:func:`forced_align`.
            Shape: `(time, )`. When computing the token-level score, the given scores are
            averaged across the corresponding time span.
        blank (int, optional): The index of the blank token. (Default: 0)

    Returns:
        list of TokenSpan

    Example:
        >>> aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths)
        >>> token_spans = merge_tokens(aligned_tokens[0], scores[0])
    """
    if tokens.ndim != 1 or scores.ndim != 1:
        raise ValueError("`tokens` and `scores` must be 1D Tensor.")
    if len(tokens) != len(scores):
        raise ValueError("`tokens` and `scores` must be the same length.")

    # Find the indices where the token value changes. Prepending and appending a
    # sentinel (-1) guarantees change points at position 0 and at len(tokens),
    # so every run of identical tokens is bounded on both sides.
    diff = torch.diff(
        tokens, prepend=torch.tensor([-1], device=tokens.device), append=torch.tensor([-1], device=tokens.device)
    )
    changes_wo_blank = torch.nonzero((diff != 0)).squeeze().tolist()
    tokens = tokens.tolist()
    # Each consecutive pair of change points delimits a run of identical tokens;
    # keep the non-blank runs and average their frame-level scores.
    spans = [
        TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
        for start, end in zip(changes_wo_blank[:-1], changes_wo_blank[1:])
        if (token := tokens[start]) != blank
    ]
    return spans
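

# Illustrative, hedged end-to-end sketch (not part of the library API): builds a
# synthetic emission, force-aligns a short target sequence, and merges the
# frame-level alignment into token spans. All shapes and values below are made
# up for demonstration and assume the torchaudio alignment extension is available.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_frames, num_classes = 20, 5  # T frames, C classes (index 0 is the blank)
    emission = torch.randn(1, num_frames, num_classes)
    log_probs = torch.log_softmax(emission, dim=-1)  # (1, T, C)
    targets = torch.tensor([[1, 2, 2, 3]], dtype=torch.int32)  # (1, L), no blank indices

    paths, scores = forced_align(log_probs, targets, blank=0)
    token_spans = merge_tokens(paths[0], scores[0])
    for span in token_spans:
        print(f"token={span.token} frames=[{span.start}, {span.end}) score={span.score:.3f}")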