Spaces:
Running
Running
from abc import ABC, abstractmethod | |
from typing import Dict, List | |
import torch | |
import torchaudio.functional as F | |
from torch import Tensor | |
from torchaudio.functional import TokenSpan | |
class ITokenizer(ABC): | |
def __call__(self, transcript: List[str]) -> List[List[str]]: | |
"""Tokenize the given transcript (list of word) | |
.. note:: | |
The toranscript must be normalized. | |
Args: | |
transcript (list of str): Transcript (list of word). | |
Returns: | |
(list of int): List of token sequences | |
""" | |
class Tokenizer(ITokenizer): | |
def __init__(self, dictionary: Dict[str, int]): | |
self.dictionary = dictionary | |
def __call__(self, transcript: List[str]) -> List[List[int]]: | |
return [[self.dictionary[c] for c in word] for word in transcript] | |
def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0): | |
device = emission.device | |
emission = emission.unsqueeze(0) | |
targets = torch.tensor([tokens], dtype=torch.int32, device=device) | |
aligned_tokens, scores = F.forced_align(emission, targets, blank=blank) | |
scores = scores.exp() # convert back to probability | |
aligned_tokens, scores = aligned_tokens[0], scores[0] # remove batch dimension | |
return aligned_tokens, scores | |
class IAligner(ABC): | |
def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]: | |
"""Generate list of time-stamped token sequences | |
Args: | |
emission (Tensor): Sequence of token probability distributions in log-domain. | |
Shape: `(time, tokens)`. | |
tokens (list of integer sequence): Tokenized transcript. | |
Output from :py:class:`torchaudio.pipelines.Wav2Vec2FABundle.Tokenizer`. | |
Returns: | |
(list of TokenSpan sequence): Tokens with time stamps and scores. | |
""" | |
def _unflatten(list_, lengths): | |
assert len(list_) == sum(lengths) | |
i = 0 | |
ret = [] | |
for l in lengths: | |
ret.append(list_[i : i + l]) | |
i += l | |
return ret | |
def _flatten(nested_list): | |
return [item for list_ in nested_list for item in list_] | |
class Aligner(IAligner): | |
def __init__(self, blank): | |
self.blank = blank | |
def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]: | |
if emission.ndim != 2: | |
raise ValueError(f"The input emission must be 2D. Found: {emission.shape}") | |
aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank) | |
spans = F.merge_tokens(aligned_tokens, scores) | |
return _unflatten(spans, [len(ts) for ts in tokens]) | |