import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import jiwer
from local.convert_metrics import nat2avaMOS, WER2INTELI
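
# Pipeline overview:
#   ASR (Whisper) -> WER -> intelligibility score
#   UTMOS -> naturalness score (converted to the AVA MOS scale)
#   Phoneme CTC + VAD-trimmed duration -> speaking rate (phonemes per minute)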
# ASR part
from transformers import pipeline

# p = pipeline("automatic-speech-recognition")
p = pipeline(
    "automatic-speech-recognition",
    model="KevinGeng/whipser_medium_en_PAL300_step25",
)
# WER part
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
])
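# The same text normalization is applied to both the reference and the ASR hypothesis when computing WER.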
# PPM part (phoneme recognition for phonemes per minute)
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
# phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
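
# Linear-interpolation resampler: converts the loaded waveform to the 16 kHz
# rate fed to the MOS and phoneme models.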
class ChangeSampleRate(nn.Module):
    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = torch.arange(new_length) * (self.input_rate / self.output_rate)
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        # Linear interpolation between the two nearest input samples
        output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(0) + round_up * indices.fmod(1.0).unsqueeze(0)
        return output
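
# UTMOS-strong (without phoneme encoder) checkpoint, per the demo description below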
model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()

def calc_mos(audio_path, ref):
    wav, sr = torchaudio.load(audio_path, channels_first=True)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # Downmix to mono
    osr = 16_000
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)
    # ASR
    trans = p(audio_path)["text"]
    # WER
    wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
    # Convert WER to an intelligibility score
    INTELI_score = WER2INTELI(wer * 100)
    # MOS
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)
        # Rescale the normalized model output to the 1-5 MOS range
        predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
    # MOS to AVA MOS
    AVA_MOS = nat2avaMOS(predic_mos)
    # Phonemes per minute (PPM)
    with torch.no_grad():
        logits = phoneme_model(out_wavs).logits
    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
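    # Count decoded phonemes and normalize by the VAD-trimmed speech duration (in minutes)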
    lst_phonemes = phone_transcription[0].split(" ")
    wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
    ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
    return AVA_MOS, trans, INTELI_score, phone_transcription, ppm

description = """
MOS prediction demo using the UTMOS-strong (without phoneme encoder) model, trained on the main-track dataset.
This demo only accepts .wav files and works best at a 16 kHz sampling rate.
The paper is available [here](https://arxiv.org/abs/2204.02152).

ASR is based on a Whisper medium English model; currently only English is supported.
A WER-based intelligibility interface is included.
"""

iface = gr.Interface(
    fn=calc_mos,
    inputs=[
        gr.Audio(type='filepath', label="Audio to evaluate"),
        gr.Textbox(placeholder="Input reference here (don't leave this empty)", label="Reference"),
    ],
    outputs=[
        gr.Textbox(placeholder="Naturalness Score", label="Naturalness Score (0-5, higher is better)"),
        gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
        gr.Textbox(placeholder="Intelligibility Score", label="Intelligibility Score (0-100, higher is better)"),
        gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes"),
        gr.Textbox(placeholder="Speaking rate, phonemes per minute", label="PPM"),
    ],
    title="Laronix's Voice Quality Checking System Demo",
    description=description,
    allow_flagging="auto",
)

iface.launch()