bel32123 commited on
Commit
1e93f37
·
1 Parent(s): 2114839

Adjust to use multitask model

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -3,17 +3,17 @@ from speechbrain.pretrained import GraphemeToPhoneme
3
  import os
4
  import torchaudio
5
  from wav2vecasr.MispronounciationDetector import MispronounciationDetector
6
- from wav2vecasr.PhonemeASRModel import Wav2Vec2PhonemeASRModel, Wav2Vec2OptimisedPhonemeASRModel
7
-
8
 
9
  @st.cache_resource
10
  def load_model():
11
- path = os.path.join(os.getcwd(), "wav2vecasr", "model", "checkpoint-600")
12
- asr_model = Wav2Vec2OptimisedPhonemeASRModel(path, os.path.join(path, "wav2vec2_vocab_final.json"),
13
- os.path.join(os.getcwd(), "wav2vecasr", "pretrained_models",
14
- "en-kenlm-model", "en.arpa.bin"))
15
  g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
16
- mispronounciation_detector = MispronounciationDetector(asr_model, g2p, "cpu")
17
  return mispronounciation_detector
18
 
19
 
@@ -55,12 +55,10 @@ def mispronounciation_detection_section():
55
  # start prediction
56
  st.write('# Detection Results')
57
  with st.spinner('Predicting...'):
58
- raw_info = mispronunciation_detector.detect(audio, text)
59
 
60
  st.write('#### Phoneme Level Analysis')
61
  st.write(f"Phoneme Error Rate: {round(raw_info['per'],2)}")
62
- # enable horizontal scrolling for phoneme output
63
- #st.text_area(label="Aligned phoneme outputs", value=raw_info['phoneme_output'],height=150)
64
  st.markdown(
65
  f"""
66
  <style>
@@ -69,9 +67,9 @@ def mispronounciation_detection_section():
69
  }}
70
  </style>
71
  ```
72
- {" ".join(raw_info['ref'])}
73
- {" ".join(raw_info['hyp'])}
74
- {" ".join(raw_info['phoneme_errors'])}
75
  ```
76
  """,
77
  unsafe_allow_html=True,
 
3
  import os
4
  import torchaudio
5
  from wav2vecasr.MispronounciationDetector import MispronounciationDetector
6
+ from wav2vecasr.PhonemeASRModel import Wav2Vec2PhonemeASRModel, Wav2Vec2OptimisedPhonemeASRModel, MultitaskPhonemeASRModel
7
+ import torch
8
 
9
  @st.cache_resource
10
  def load_model():
11
+ path = os.path.join(os.getcwd(), "wav2vecasr", "model", "multitask_best_ctc.pt")
12
+ vocab_path = os.path.join(os.getcwd(), "wav2vecasr", "model", "vocab")
13
+ device = "cpu"
14
+ asr_model = MultitaskPhonemeASRModel(path, vocab_path, device)
15
  g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
16
+ mispronounciation_detector = MispronounciationDetector(asr_model, g2p, device)
17
  return mispronounciation_detector
18
 
19
 
 
55
  # start prediction
56
  st.write('# Detection Results')
57
  with st.spinner('Predicting...'):
58
+ raw_info = mispronunciation_detector.detect(audio, text, phoneme_error_threshold=0.25)
59
 
60
  st.write('#### Phoneme Level Analysis')
61
  st.write(f"Phoneme Error Rate: {round(raw_info['per'],2)}")
 
 
62
  st.markdown(
63
  f"""
64
  <style>
 
67
  }}
68
  </style>
69
  ```
70
+ {raw_info['ref']}
71
+ {raw_info['hyp']}
72
+ {raw_info['phoneme_errors']}
73
  ```
74
  """,
75
  unsafe_allow_html=True,