from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torchaudio._extension import fail_if_no_align

__all__ = []


@fail_if_no_align
def forced_align(
    log_probs: Tensor,
    targets: Tensor,
    input_lengths: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    blank: int = 0,
) -> Tuple[Tensor, Tensor]:
    r"""Align a CTC label sequence to an emission.



    .. devices:: CPU CUDA



    .. properties:: TorchScript



    Args:

        log_probs (Tensor): log probability of CTC emission output.

            Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,

            `C` is the number of characters in alphabet including blank.

        targets (Tensor): Target sequence. Tensor of shape `(B, L)`,

            where `L` is the target length.

        input_lengths (Tensor or None, optional):

            Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.

        target_lengths (Tensor or None, optional):

            Lengths of the targets. 1-D Tensor of shape `(B,)`.

        blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0)



    Returns:

        Tuple(Tensor, Tensor):

            Tensor: Label for each time step in the alignment path computed using forced alignment.



            Tensor: Log probability scores of the labels for each time step.



    Note:

        The sequence length of `log_probs` must satisfy:





        .. math::

            L_{\text{log\_probs}} \ge L_{\text{label}} + N_{\text{repeat}}



        where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.

        For example, in str `"aabbc"`, the number of repeats are `2`.



    Note:

        The current version only supports ``batch_size==1``.
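
    Example:
        A minimal illustrative sketch; the emission below is random placeholder data,
        not output from a real model:

        >>> log_probs = torch.randn(1, 100, 32).log_softmax(dim=-1)  # (B=1, T=100, C=32)
        >>> targets = torch.tensor([[5, 12, 12, 7]], dtype=torch.int32)  # (B=1, L=4), no blank index
        >>> aligned_tokens, scores = forced_align(log_probs, targets)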

    """
    if blank in targets:
        raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
    if torch.max(targets) >= log_probs.shape[-1]:
        raise ValueError("targets values must be less than the CTC dimension")

    if input_lengths is None:
        batch_size, length = log_probs.size(0), log_probs.size(1)
        input_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=log_probs.device)
    if target_lengths is None:
        batch_size, length = targets.size(0), targets.size(1)
        target_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=targets.device)

    # For TorchScript compatibility
    assert input_lengths is not None
    assert target_lengths is not None

    paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
    return paths, scores


@dataclass
class TokenSpan:
    """TokenSpan()

    Token with time stamps and score. Returned by :py:func:`merge_tokens`.

    """

    token: int
    """The token"""
    start: int
    """The start time (inclusive) in emission time axis."""
    end: int
    """The end time (exclusive) in emission time axis."""
    score: float
    """The score of the this token."""

    def __len__(self) -> int:
        """Returns the time span"""
        return self.end - self.start


def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSpan]:
    """Removes repeated tokens and blank tokens from the given CTC token sequence.



    Args:

        tokens (Tensor): Alignment tokens (unbatched) returned from :py:func:`forced_align`.

            Shape: `(time, )`.

        scores (Tensor): Alignment scores (unbatched) returned from :py:func:`forced_align`.

            Shape: `(time, )`. When computing the token-size score, the given score is averaged

            across the corresponding time span.



    Returns:

        list of TokenSpan



    Example:

        >>> aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths)

        >>> token_spans = merge_tokens(aligned_tokens[0], scores[0])
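
        The spans hold frame indices; converting them to seconds is sketched below,
        where ``waveform`` and ``sample_rate`` are the (hypothetical) input audio and
        its sampling rate:

        >>> ratio = waveform.size(1) / emission.size(1)  # samples per emission frame
        >>> for span in token_spans:
        ...     print(span.token, span.start * ratio / sample_rate, span.end * ratio / sample_rate)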

    """
    if tokens.ndim != 1 or scores.ndim != 1:
        raise ValueError("`tokens` and `scores` must be 1D Tensor.")
    if len(tokens) != len(scores):
        raise ValueError("`tokens` and `scores` must be the same length.")

    # A new span starts wherever the token value changes between adjacent frames.
    # The prepended/appended -1 sentinels mark the start of the first run and the
    # end of the last run as change points as well.
    diff = torch.diff(
        tokens, prepend=torch.tensor([-1], device=tokens.device), append=torch.tensor([-1], device=tokens.device)
    )
    changes_wo_blank = torch.nonzero((diff != 0)).squeeze().tolist()
    tokens = tokens.tolist()
    # Build one span per run of identical tokens, dropping runs of the blank token
    # and averaging the frame scores over each span.
    spans = [
        TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
        for start, end in zip(changes_wo_blank[:-1], changes_wo_blank[1:])
        if (token := tokens[start]) != blank
    ]
    return spans