import os
import torch
from transformers import AutoFeatureExtractor, WhisperModel, AutoModelForSpeechSeq2Seq
import numpy as np
import torchaudio
import librosa
import spaces

import gradio as gr
from modules import load_audio, MosPredictor, denorm
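# MosPredictor, load_audio and denorm come from the accompanying modules.py
# (the MOSA-Net+ model class and its audio/score helpers).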


mos_checkpoint = "ckpt_mosa_net_plus"

print('Loading MOSA-Net+ checkpoint...')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = MosPredictor().to(device)
model.load_state_dict(torch.load(mos_checkpoint, map_location=device))
model.eval()

print('Loading Whisper checkpoint...')
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3")
#model_asli = WhisperModel.from_pretrained("openai/whisper-large-v3") 
model_asli = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v3",
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
model_asli = model_asli.to(device)
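# Whisper large-v3 is used only as a frozen feature extractor: the code below
# takes the encoder's hidden states and never generates a transcription.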

@spaces.GPU
def predict_mos(wavefile: str):
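    """Predict quality and intelligibility scores for one audio file.

    `wavefile` is the file path provided by the Gradio Audio component. The clip
    is converted into the inputs MOSA-Net+ expects: the raw waveform, an STFT
    magnitude spectrogram, and Whisper encoder hidden states (cross-domain features).
    """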
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if device != model.device:
        model.to(device)
    if device != model_asli.device:
        model_asli.to(device)
    
    
    print('Starting prediction...')
    # STFT: 512-point FFT, 256-sample hop -> magnitude spectrogram of shape (1, 1, frames, 257)
    wav = torchaudio.load(wavefile)[0]
    spec = np.abs(librosa.stft(wav[0].detach().numpy(), n_fft=512, hop_length=256, win_length=512)).T
    lps = torch.from_numpy(np.expand_dims(spec, axis=0)).unsqueeze(1)

    # Whisper features: log-mel input features for the Whisper encoder
    audio = load_audio(wavefile)
    inputs = feature_extractor(audio, return_tensors="pt")
    # match the encoder's dtype (float16 on GPU) to avoid a dtype mismatch
    input_features = inputs.input_features.to(device, dtype=model_asli.dtype)

    with torch.no_grad():
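        # Only the encoder output is used as the cross-domain feature; the dummy
        # decoder_input_ids (decoder start token) just satisfy the seq2seq forward.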
        decoder_input_ids = torch.tensor([[1, 1]]) * model_asli.config.decoder_start_token_id
        decoder_input_ids = decoder_input_ids.to(device)
        last_hidden_state = model_asli(input_features, decoder_input_ids=decoder_input_ids).encoder_last_hidden_state
        whisper_feat = last_hidden_state.float()  # back to float32 for the MOS predictor

    print('Model features shapes...')
    print(whisper_feat.shape)
    print(wav.shape)
    print(lps.shape)

    # Prediction: utterance-level quality/intelligibility scores plus frame-level scores
    wav = wav.to(device)
    lps = lps.to(device)
    Quality_1, Intell_1, frame1, frame2 = model(wav, lps, whisper_feat)
    quality_pred = Quality_1.cpu().detach().numpy()[0]
    intell_pred = Intell_1.cpu().detach().numpy()[0]

    print("predictions")
    qa_text = f"Quality: {denorm(quality_pred)[0]:.2f}  Inteligibility: {intell_pred[0]:.2f}" 
    print(qa_text)
    return qa_text


title =  """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
    <div
        style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
    > <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
        MOSA-Net Whisper features
    </h1> </div>
</div>
""" 

description = """
This is a demo of [MOSA-Net+](https://github.com/dhimasryan/MOSA-Net-Cross-Domain/tree/main/MOSA_Net%2B), an improved version of
MOSA-Net that predicts human-perceived speech quality and intelligibility. MOSA-Net+ uses Whisper to generate cross-domain features.
The model employs a CNN-BLSTM architecture with an attention mechanism and is trained with a multi-task learning approach to predict
subjective listening-test scores.  
MOSA-Net+ was evaluated on the noisy-and-enhanced track of the VoiceMOS Challenge 2023, where it achieved the top-ranked performance among nine systems ([full paper](https://arxiv.org/abs/2309.12766)).
"""

article = """
If this model contributes to your research, please cite the following works:

R. E. Zezario, S.-W. Fu, F. Chen, C.-S. Fuh, H.-M. Wang and Y. Tsao, "Deep Learning-Based Non-Intrusive Multi-Objective Speech Assessment Model With Cross-Domain Features," in IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 54-70, 2023, doi: 10.1109/TASLP.2022.3205757.

R. E. Zezario, Y.-W. Chen, S.-W. Fu, Y. Tsao, H.-M. Wang and C.-S. Fuh, "A Study on Incorporating Whisper for Robust Speech Assessment," IEEE ICME 2024, July 2024 (top performance on Track 3 of the VoiceMOS Challenge 2023).

Demo contributed by [@wetdog](https://github.com/wetdog)
"""
demo = gr.Blocks()
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Interface(
        fn=predict_mos,
        inputs=gr.Audio(type='filepath'),
        outputs="text",
        allow_flagging="never",
    )
    gr.Markdown(article)

demo.queue(max_size=10)
demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)