Spaces:
Runtime error
Runtime error
Adjust to use multitask model
Browse files
app.py
CHANGED
@@ -3,17 +3,17 @@ from speechbrain.pretrained import GraphemeToPhoneme
|
|
3 |
import os
|
4 |
import torchaudio
|
5 |
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
|
6 |
-
from wav2vecasr.PhonemeASRModel import Wav2Vec2PhonemeASRModel, Wav2Vec2OptimisedPhonemeASRModel
|
7 |
-
|
8 |
|
9 |
@st.cache_resource
|
10 |
def load_model():
|
11 |
-
path = os.path.join(os.getcwd(), "wav2vecasr", "model", "
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
|
16 |
-
mispronounciation_detector = MispronounciationDetector(asr_model, g2p,
|
17 |
return mispronounciation_detector
|
18 |
|
19 |
|
@@ -55,12 +55,10 @@ def mispronounciation_detection_section():
|
|
55 |
# start prediction
|
56 |
st.write('# Detection Results')
|
57 |
with st.spinner('Predicting...'):
|
58 |
-
raw_info = mispronunciation_detector.detect(audio, text)
|
59 |
|
60 |
st.write('#### Phoneme Level Analysis')
|
61 |
st.write(f"Phoneme Error Rate: {round(raw_info['per'],2)}")
|
62 |
-
# enable horizontal scrolling for phoneme output
|
63 |
-
#st.text_area(label="Aligned phoneme outputs", value=raw_info['phoneme_output'],height=150)
|
64 |
st.markdown(
|
65 |
f"""
|
66 |
<style>
|
@@ -69,9 +67,9 @@ def mispronounciation_detection_section():
|
|
69 |
}}
|
70 |
</style>
|
71 |
```
|
72 |
-
{
|
73 |
-
{
|
74 |
-
{
|
75 |
```
|
76 |
""",
|
77 |
unsafe_allow_html=True,
|
|
|
3 |
import os
|
4 |
import torchaudio
|
5 |
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
|
6 |
+
from wav2vecasr.PhonemeASRModel import Wav2Vec2PhonemeASRModel, Wav2Vec2OptimisedPhonemeASRModel, MultitaskPhonemeASRModel
|
7 |
+
import torch
|
8 |
|
9 |
@st.cache_resource
|
10 |
def load_model():
|
11 |
+
path = os.path.join(os.getcwd(), "wav2vecasr", "model", "multitask_best_ctc.pt")
|
12 |
+
vocab_path = os.path.join(os.getcwd(), "wav2vecasr", "model", "vocab")
|
13 |
+
device = "cpu"
|
14 |
+
asr_model = MultitaskPhonemeASRModel(path, vocab_path, device)
|
15 |
g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
|
16 |
+
mispronounciation_detector = MispronounciationDetector(asr_model, g2p, device)
|
17 |
return mispronounciation_detector
|
18 |
|
19 |
|
|
|
55 |
# start prediction
|
56 |
st.write('# Detection Results')
|
57 |
with st.spinner('Predicting...'):
|
58 |
+
raw_info = mispronunciation_detector.detect(audio, text, phoneme_error_threshold=0.25)
|
59 |
|
60 |
st.write('#### Phoneme Level Analysis')
|
61 |
st.write(f"Phoneme Error Rate: {round(raw_info['per'],2)}")
|
|
|
|
|
62 |
st.markdown(
|
63 |
f"""
|
64 |
<style>
|
|
|
67 |
}}
|
68 |
</style>
|
69 |
```
|
70 |
+
{raw_info['ref']}
|
71 |
+
{raw_info['hyp']}
|
72 |
+
{raw_info['phoneme_errors']}
|
73 |
```
|
74 |
""",
|
75 |
unsafe_allow_html=True,
|