from typing import List, Optional, Tuple

import torch
from torch import nn, Tensor
from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model


def _get_model(type_, params):
    factories = {
        "Wav2Vec2": wav2vec2_model,
        "WavLM": wavlm_model,
    }
    if type_ not in factories:
        raise ValueError(f"Supported model types are {tuple(factories.keys())}. Found: {type_}")
    factory = factories[type_]
    return factory(**params)
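

# Usage sketch for ``_get_model`` (illustrative only; real parameter dicts come
# from the bundle definitions and are much larger). Dispatch is keyed on
# ``type_``, and an unsupported type raises immediately:
#
#     >>> _get_model("HuBERT", {})
#     Traceback (most recent call last):
#         ...
#     ValueError: Supported model types are ('Wav2Vec2', 'WavLM'). Found: HuBERT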


class _Wav2Vec2Model(nn.Module):
    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.

    This is used to apply layer normalization to the input waveform, and
    optionally log-softmax and an extra "star" dimension to the output.
    """

    def __init__(self, model: Wav2Vec2Model, normalize_waveform: bool, apply_log_softmax: bool, append_star: bool):
        super().__init__()
        self.model = model
        self.normalize_waveform = normalize_waveform
        self.apply_log_softmax = apply_log_softmax
        self.append_star = append_star

    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        if self.normalize_waveform:
            waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
        output, output_lengths = self.model(waveforms, lengths)
        if self.apply_log_softmax:
            output = torch.nn.functional.log_softmax(output, dim=-1)
        if self.append_star:
            # Append an extra all-zero class (the "star" token) to the emission,
            # matching the batch and frame dimensions of the model output.
            star_dim = torch.zeros((output.size(0), output.size(1), 1), dtype=output.dtype, device=output.device)
            output = torch.cat((output, star_dim), dim=-1)
        return output, output_lengths

    def extract_features(
        self,
        waveforms: Tensor,
        lengths: Optional[Tensor] = None,
        num_layers: Optional[int] = None,
    ) -> Tuple[List[Tensor], Optional[Tensor]]:
        if self.normalize_waveform:
            waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
        return self.model.extract_features(waveforms, lengths, num_layers)


def _extend_model(module, normalize_waveform, apply_log_softmax=False, append_star=False):
    """Add extra transformations to the model"""
    return _Wav2Vec2Model(module, normalize_waveform, apply_log_softmax, append_star)
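

# A minimal sketch of how the wrapper is meant to be used (``params`` and
# ``waveform`` are hypothetical; actual configurations come from the pipeline
# bundle definitions):
#
#     >>> model = _get_model("Wav2Vec2", params)  # base encoder + aux head
#     >>> model = _extend_model(model, normalize_waveform=True, append_star=True)
#     >>> emission, _ = model(waveform)  # (batch, frame, num_labels + 1)
#
# With ``append_star=True`` the last class of ``emission`` is the all-zero
# "star" dimension appended in ``_Wav2Vec2Model.forward``.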


def _remove_aux_axes(state_dict, axes):
    # Remove the seemingly unnecessary axes.
    # For the ASR task, pretrained weights that originate from fairseq have unrelated
    # dimensions at indices 1, 2, 3. They come from fairseq's Dictionary implementation,
    # which was designed for NLP tasks and is not used during ASR training.
    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
    #
    # Also, some pretrained weights that originate from VoxPopuli have an extra dimension
    # that is almost never used and that looks like a mistake:
    # the label `1` shows up in the training datasets of German (1 out of 16M),
    # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M).
    for key in ["aux.weight", "aux.bias"]:
        mat = state_dict[key]
        state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])
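

# Worked example (hypothetical shapes): stripping fairseq's <pad>/<eos>/<unk>
# rows (indices 1, 2, 3) from a 32-class CTC head leaves 29 output classes:
#
#     >>> sd = {"aux.weight": torch.randn(32, 768), "aux.bias": torch.randn(32)}
#     >>> _remove_aux_axes(sd, [1, 2, 3])
#     >>> sd["aux.weight"].shape, sd["aux.bias"].shape
#     (torch.Size([29, 768]), torch.Size([29]))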


def _get_state_dict(url, dl_kwargs, remove_axes=None):
    if not url.startswith("https"):
        url = f"https://download.pytorch.org/torchaudio/models/{url}"
    dl_kwargs = {} if dl_kwargs is None else dl_kwargs
    state_dict = load_state_dict_from_url(url, **dl_kwargs)
    if remove_axes:
        _remove_aux_axes(state_dict, remove_axes)
    return state_dict
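

# Usage sketch (the file name below is hypothetical): a bare file name is
# resolved against the torchaudio download root, so these two calls fetch the
# same weights:
#
#     >>> _get_state_dict("model_file.pth", None)
#     >>> _get_state_dict("https://download.pytorch.org/torchaudio/models/model_file.pth", None)
#
# Passing ``remove_axes=[1, 2, 3]`` additionally strips the fairseq dictionary
# rows from the aux head, as described above.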


def _get_en_labels():
    return (
        "|",
        "E",
        "T",
        "A",
        "O",
        "N",
        "I",
        "H",
        "S",
        "R",
        "D",
        "L",
        "U",
        "M",
        "W",
        "C",
        "F",
        "G",
        "Y",
        "P",
        "B",
        "V",
        "K",
        "'",
        "X",
        "J",
        "Q",
        "Z",
    )
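

# The label tuples in this module enumerate the CTC output classes in
# model-output order, excluding the blank symbol at index 0 (an assumption
# consistent with ``_remove_aux_axes`` above, which strips fairseq's
# <pad>/<eos>/<unk> rows but keeps the blank). A hypothetical greedy decode
# over an ``emission`` of shape (frame, class):
#
#     >>> labels = ("-",) + _get_en_labels()  # "-" stands in for the CTC blank
#     >>> indices = torch.unique_consecutive(torch.argmax(emission, dim=-1))
#     >>> "".join(labels[i] for i in indices if i != 0).replace("|", " ")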


def _get_de_labels():
    return (
        "|",
        "e",
        "n",
        "i",
        "r",
        "s",
        "t",
        "a",
        "d",
        "h",
        "u",
        "l",
        "g",
        "c",
        "m",
        "o",
        "b",
        "w",
        "f",
        "k",
        "z",
        "p",
        "v",
        "ü",
        "ä",
        "ö",
        "j",
        "ß",
        "y",
        "x",
        "q",
    )


def _get_vp_en_labels():
    return (
        "|",
        "e",
        "t",
        "o",
        "i",
        "a",
        "n",
        "s",
        "r",
        "h",
        "l",
        "d",
        "c",
        "u",
        "m",
        "p",
        "f",
        "g",
        "w",
        "y",
        "b",
        "v",
        "k",
        "x",
        "j",
        "q",
        "z",
    )


def _get_es_labels():
    return (
        "|",
        "e",
        "a",
        "o",
        "s",
        "n",
        "r",
        "i",
        "l",
        "d",
        "c",
        "t",
        "u",
        "p",
        "m",
        "b",
        "q",
        "y",
        "g",
        "v",
        "h",
        "ó",
        "f",
        "í",
        "á",
        "j",
        "z",
        "ñ",
        "é",
        "x",
        "ú",
        "k",
        "w",
        "ü",
    )


def _get_fr_labels():
    return (
        "|",
        "e",
        "s",
        "n",
        "i",
        "t",
        "r",
        "a",
        "o",
        "u",
        "l",
        "d",
        "c",
        "p",
        "m",
        "é",
        "v",
        "q",
        "f",
        "g",
        "b",
        "h",
        "x",
        "à",
        "j",
        "è",
        "y",
        "ê",
        "z",
        "ô",
        "k",
        "ç",
        "œ",
        "û",
        "ù",
        "î",
        "â",
        "w",
        "ï",
        "ë",
        "ü",
        "æ",
    )


def _get_it_labels():
    return (
        "|",
        "e",
        "i",
        "a",
        "o",
        "n",
        "t",
        "r",
        "l",
        "s",
        "c",
        "d",
        "u",
        "p",
        "m",
        "g",
        "v",
        "h",
        "z",
        "f",
        "b",
        "q",
        "à",
        "è",
        "ù",
        "é",
        "ò",
        "ì",
        "k",
        "y",
        "x",
        "w",
        "j",
        "ó",
        "í",
        "ï",
    )


def _get_mms_labels():
    return (
        "a",
        "i",
        "e",
        "n",
        "o",
        "u",
        "t",
        "s",
        "r",
        "m",
        "k",
        "l",
        "d",
        "g",
        "h",
        "y",
        "b",
        "p",
        "w",
        "c",
        "v",
        "j",
        "z",
        "f",
        "'",
        "q",
        "x",
    )
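

# Note on the MMS label set: when the model is wrapped with ``append_star=True``,
# the effective vocabulary is blank + these labels + "*" (an inference from the
# star dimension appended in ``_Wav2Vec2Model.forward`` above):
#
#     >>> labels = ("-",) + _get_mms_labels() + ("*",)
#     >>> len(labels)
#     29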