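"""Gradio demo for MOSA-Net+: non-intrusive prediction of speech quality and
intelligibility from an input audio file, using Whisper encoder features."""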
import os
import torch
from transformers import AutoFeatureExtractor, WhisperModel, AutoModelForSpeechSeq2Seq
import numpy as np
import torchaudio
import librosa
import spaces
import gradio as gr
from modules import load_audio, MosPredictor, denorm
mos_checkpoint = "ckpt_mosa_net_plus"

print('Loading MOSA-Net+ checkpoint...')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the MOSA-Net+ predictor weights and switch to inference mode.
model = MosPredictor().to(device)
model.load_state_dict(torch.load(mos_checkpoint, map_location=device))
model.eval()
print('Loading Whisper checkpoint...')
# Whisper large-v3 is used only as a feature extractor: its encoder's last
# hidden state provides the cross-domain features consumed by MOSA-Net+.
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3")
#model_asli = WhisperModel.from_pretrained("openai/whisper-large-v3")
model_asli = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v3",
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
model_asli = model_asli.to(device)
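
# @spaces.GPU requests a ZeroGPU slot for the duration of each call, so the
# models are moved to the allocated device inside predict_mos itself.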
@spaces.GPU
def predict_mos(wavefile: str):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # ``.to`` is a no-op when a model already lives on the target device.
    model.to(device)
    model_asli.to(device)
    print('Starting prediction...')
    # STFT: magnitude spectrogram of the first channel, shaped as expected by MOSA-Net+.
    wav = torchaudio.load(wavefile)[0]
    spec = np.abs(librosa.stft(wav[0].detach().numpy(), n_fft=512, hop_length=256, win_length=512))
    lps = torch.from_numpy(np.expand_dims(spec.T, axis=0))
    lps = lps.unsqueeze(1)

    # Whisper features: log-mel input features for the Whisper encoder.
    audio = load_audio(wavefile)
    inputs = feature_extractor(audio, return_tensors="pt")
    input_features = inputs.input_features
    # Match the Whisper model's dtype (float16 when it was loaded on GPU).
    input_features = input_features.to(device, dtype=model_asli.dtype)
    with torch.no_grad():
        # Dummy decoder input ids: only the encoder's last hidden state is
        # needed as the cross-domain feature for MOSA-Net+.
        decoder_input_ids = torch.tensor([[1, 1]]) * model_asli.config.decoder_start_token_id
        decoder_input_ids = decoder_input_ids.to(device)
        last_hidden_state = model_asli(input_features, decoder_input_ids=decoder_input_ids).encoder_last_hidden_state
        # Cast back to float32 for the MOSA-Net+ predictor, which runs in full precision.
        whisper_feat = last_hidden_state.float()

    print('Model features shapes...')
    print(whisper_feat.shape)
    print(wav.shape)
    print(lps.shape)
    # Prediction: MOSA-Net+ returns utterance-level and frame-level scores;
    # only the utterance-level quality and intelligibility are reported here.
    wav = wav.to(device)
    lps = lps.to(device)
    Quality_1, Intell_1, frame1, frame2 = model(wav, lps, whisper_feat)
    quality_pred = Quality_1.cpu().detach().numpy()[0]
    intell_pred = Intell_1.cpu().detach().numpy()[0]

    print("predictions")
    # denorm (from modules) maps the normalized quality prediction back to the score scale.
    qa_text = f"Quality: {denorm(quality_pred)[0]:.2f} Intelligibility: {intell_pred[0]:.2f}"
    print(qa_text)
    return qa_text
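
# Quick local sanity check (hypothetical file path; in the Space this function
# is invoked through the Gradio interface defined below):
# print(predict_mos("sample.wav"))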
title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
  <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
    <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
      MOSA-Net Whisper features
    </h1>
  </div>
</div>
"""
description = """
This is a demo of [MOSA-Net+](https://github.com/dhimasryan/MOSA-Net-Cross-Domain/tree/main/MOSA_Net%2B), an improved version of MOSA-Net
that predicts human-rated speech quality and intelligibility. MOSA-Net+ uses Whisper to generate cross-domain features. The model employs
a CNN-BLSTM architecture with an attention mechanism and is trained with a multi-task learning approach to predict subjective listening-test
scores.

MOSA-Net+ was evaluated in the noisy-and-enhanced track of the VoiceMOS Challenge 2023, where it obtained the top-ranked performance among nine systems ([full paper](https://arxiv.org/abs/2309.12766)).
"""
article = """
If the model contributes to your research, please cite the following works:

R. E. Zezario, S.-W. Fu, F. Chen, C.-S. Fuh, H.-M. Wang and Y. Tsao, "Deep Learning-Based Non-Intrusive Multi-Objective Speech Assessment Model With Cross-Domain Features," in IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 54-70, 2023, doi: 10.1109/TASLP.2022.3205757.

R. E. Zezario, Y.-W. Chen, S.-W. Fu, Y. Tsao, H.-M. Wang and C.-S. Fuh, "A Study on Incorporating Whisper for Robust Speech Assessment," IEEE ICME 2024, July 2024 (top performance on Track 3 of the VoiceMOS Challenge 2023).

Demo contributed by [@wetdog](https://github.com/wetdog)
"""
demo = gr.Blocks()
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Interface(
        fn=predict_mos,
        inputs=gr.Audio(type='filepath'),
        outputs="text",
        allow_flagging="never",
    )
    gr.Markdown(article)

demo.queue(max_size=10)
demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)