from abc import ABC, abstractmethod
from typing import Dict, List

import torch
import torchaudio.functional as F
from torch import Tensor
from torchaudio.functional import TokenSpan


class ITokenizer(ABC):
    @abstractmethod
    def __call__(self, transcript: List[str]) -> List[List[int]]:
        """Tokenize the given transcript (list of words).

        .. note::

           The transcript must be normalized.

        Args:
            transcript (list of str): Transcript (list of words).

        Returns:
            (list of integer sequences): List of token ID sequences, one per word.
        """


class Tokenizer(ITokenizer):
    """Tokenizer that maps each character of a word to its ID via the given dictionary."""

    def __init__(self, dictionary: Dict[str, int]):
        self.dictionary = dictionary

    def __call__(self, transcript: List[str]) -> List[List[int]]:
        return [[self.dictionary[c] for c in word] for word in transcript]
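
# A minimal usage sketch (hypothetical two-character dictionary; real
# dictionaries typically come from the pipeline bundle's get_dict()):
#
#   tokenizer = Tokenizer({"h": 1, "i": 2})
#   tokenizer(["hi", "hi"])  # -> [[1, 2], [1, 2]]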


def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0):
    device = emission.device
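    # forced_align expects batched inputs, so add a batch dimension of size 1.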
    emission = emission.unsqueeze(0)
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)

    aligned_tokens, scores = F.forced_align(emission, targets, blank=blank)

    scores = scores.exp()  # convert back to probability
    aligned_tokens, scores = aligned_tokens[0], scores[0]  # remove batch dimension
    return aligned_tokens, scores
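
# A minimal shape sketch for _align_emission_and_tokens (random emission and
# illustrative token IDs; not part of this module's API):
#
#   emission = torch.randn(100, 30).log_softmax(-1)  # (time, tokens) log-probs
#   aligned, scores = _align_emission_and_tokens(emission, [5, 3, 3, 8])
#   # Outputs are frame-level: one token ID and one probability per frame.
#   assert aligned.shape == scores.shape == (100,)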


class IAligner(ABC):
    @abstractmethod
    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
        """Generate a list of time-stamped token sequences.

        Args:
            emission (Tensor): Sequence of token probability distributions in log-domain.
                Shape: `(time, tokens)`.
            tokens (list of integer sequences): Tokenized transcript.
                Output from :py:class:`torchaudio.pipelines.Wav2Vec2FABundle.Tokenizer`.

        Returns:
            (list of TokenSpan sequences): Tokens with time stamps and scores.
        """


def _unflatten(list_, lengths):
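    # e.g. _unflatten([1, 2, 3, 4, 5], [2, 3]) -> [[1, 2], [3, 4, 5]]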
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for length in lengths:
        ret.append(list_[i : i + length])
        i += length
    return ret


def _flatten(nested_list):
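    # e.g. _flatten([[1, 2], [3, 4, 5]]) -> [1, 2, 3, 4, 5]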
    return [item for list_ in nested_list for item in list_]


class Aligner(IAligner):
    def __init__(self, blank: int):
        self.blank = blank

    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
        if emission.ndim != 2:
            raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")

        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
        spans = F.merge_tokens(aligned_tokens, scores)
        return _unflatten(spans, [len(ts) for ts in tokens])
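

# An end-to-end sketch (hypothetical `model`, `waveform`, and `dictionary`;
# in practice these come from a Wav2Vec2FABundle such as
# torchaudio.pipelines.MMS_FA):
#
#   tokenizer = Tokenizer(dictionary)          # dictionary: Dict[str, int]
#   aligner = Aligner(blank=0)
#   tokens = tokenizer(["hello", "world"])     # one token ID sequence per word
#   emission, _ = model(waveform)              # log-probs, shape (1, time, tokens)
#   word_spans = aligner(emission[0], tokens)  # List[List[TokenSpan]], one per word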