bel32123 committed
Commit 13375b8 · 1 Parent(s): 6acf275

Add prompt generation feature and articulation videos lookup

Files changed (1): app.py (+100, -5)
app.py CHANGED
@@ -3,8 +3,14 @@ from speechbrain.pretrained import GraphemeToPhoneme
  import os
  import torchaudio
  from wav2vecasr.MispronounciationDetector import MispronounciationDetector
- from wav2vecasr.PhonemeASRModel import Wav2Vec2PhonemeASRModel, Wav2Vec2OptimisedPhonemeASRModel, MultitaskPhonemeASRModel
- import torch
+ from wav2vecasr.PhonemeASRModel import MultitaskPhonemeASRModel
+ import json
+ import os
+ import random
+ import openai
+
+
+ openai.api_key = os.getenv("OPENAI_KEY")

  @st.cache_resource
  def load_model():
@@ -34,6 +40,46 @@ def get_audio(saved_sound_filename):
      audio = audio.view(audio.shape[1])
      return audio

+ @st.cache_data
+ def get_prompts():
+     prompts_path = os.path.join(os.getcwd(), "wav2vecasr", "data", "prompts.json")
+     f = open(prompts_path)
+     data = json.load(f)
+     prompts = data["prompts"]
+     return prompts
+
+ @st.cache_data
+ def get_articulation_videos():
+     # note -- not all arpabets could be mapped to a video with visualisation on articulation
+     path = os.path.join(os.getcwd(), "wav2vecasr", "data", "videos.json")
+     f = open(path)
+     data = json.load(f)
+     return data
+
+ def get_prompts_from_l2_arctic(prompts, current_prompt, num_to_get):
+     selected_prompts = []
+     while len(selected_prompts) < num_to_get:
+         prompt = random.choice(prompts)
+         if prompt not in selected_prompts and prompt != current_prompt:
+             selected_prompts.append(prompt)
+
+     return selected_prompts
+
+ def get_prompt_from_openai(words_with_error_list):
+     try:
+         words_with_errors = ", ".join(words_with_error_list)
+         response = openai.ChatCompletion.create(
+             model="gpt-3.5-turbo",
+             messages=[
+                 {"role": "system", "content": "You are writing practice reading prompts for learners of English to practice pronunciation. These prompts should be short, easy to understand and useful."},
+                 {"role": "user", "content": f"Write a short sentence of less than 10 words and include the following words in the sentence: {words_with_errors} No numbers."}
+             ]
+         )
+
+         return response['choices'][0]['message']['content']
+     except:
+         return ""
+
  def mispronounciation_detection_section():
      st.write('# Prediction')
      st.write('1. Upload a recording of you saying the text in .wav format')
@@ -52,11 +98,13 @@ def mispronounciation_detection_section():
          # load model
          mispronunciation_detector = load_model()

-         # start prediction
          st.write('# Detection Results')
          with st.spinner('Predicting...'):
+
+             # detect
              raw_info = mispronunciation_detector.detect(audio, text, phoneme_error_threshold=0.25)

+             # display prediction results for phonemes
              st.write('#### Phoneme Level Analysis')
              st.write(f"Phoneme Error Rate: {round(raw_info['per'],2)}")
              st.markdown(
@@ -76,9 +124,13 @@ def mispronounciation_detection_section():
              )

              st.divider()
+
+             # display word errors
              md = []
+             words_with_errors = []
              for word, has_error in zip(raw_info["words"], raw_info["word_errors"]):
                  if has_error:
+                     words_with_errors.append(word)
                      md.append(f"**{word}**")
                  else:
                      md.append(word)
@@ -86,19 +138,62 @@ def mispronounciation_detection_section():
              st.write('#### Word Level Analysis')
              st.write(f"Word Error Rate: {round(raw_info['wer'], 2)} and the following words in bold have errors:")
              st.markdown(" ".join(md))
+
+             st.divider()
+
+             # display more prompts to practice -- 1 from ChatGPT -- based on user's mistakes, 2 from L2 Arctic
+             st.write('#### What is next?')
+
+             st.write('Here are some more prompts for you to practice:')
+
+             selected_prompts = []
+
+             unique_words_with_errors = list(set(words_with_errors))
+             prompt_for_mistakes_made = get_prompt_from_openai(unique_words_with_errors)
+             if prompt_for_mistakes_made:
+                 selected_prompts.append(prompt_for_mistakes_made)
+
+             prompts = get_prompts()
+             l2_arctic_prompts = get_prompts_from_l2_arctic(prompts, text, 3-len(selected_prompts))
+             selected_prompts.extend(l2_arctic_prompts)
+
+             for prompt in selected_prompts:
+                 st.code(f'''{prompt}''', language="python")
+
+
      else:
          st.error('The audio or text has not been properly input', icon="🚨")
          return

+ def video_section():
+     st.write('# Get helpful videos on phoneme articulation')
+
+     problem_phoneme = st.text_input(
+         "Enter the phoneme you had problems with 👇"
+     )
+
+     arpabet_to_video_map = get_articulation_videos()
+
+     if st.button('Look up'):
+         if not problem_phoneme:
+             st.error('The audio or text has not been properly input', icon="🚨")
+         elif problem_phoneme in arpabet_to_video_map:
+             video_link = arpabet_to_video_map[problem_phoneme]["link"]
+             if video_link:
+                 st.video(video_link)
+             else:
+                 st.write("Sorry, we couldn't find a good enough video yet :( we are working on it!")
+
  if __name__ == '__main__':
      st.write('___')
      # create a sidebar
      st.sidebar.title('Pronounciation Evaluation')
-     select = st.sidebar.selectbox('', ['Main Page', 'Mispronounciation Detection'], key='1', label_visibility='collapsed')
+     select = st.sidebar.selectbox('', ['Main Page', 'Mispronounciation Detection', 'Helpful Videos for Problem Phonemes'], key='1', label_visibility='collapsed')
      st.sidebar.write(select)
      if select=='Mispronounciation Detection':
          mispronounciation_detection_section()
-     # else: stay on the home page
+     elif select=="Helpful Videos for Problem Phonemes":
+         video_section()
      else:
          st.write('# Pronounciation Evaluation')
          st.write('This app is designed to detect mispronounciation of English words for English learners from Asian countries like Korean, Mandarin and Vietnameses.')
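Note: the new get_prompts() and get_articulation_videos() helpers read wav2vecasr/data/prompts.json and wav2vecasr/data/videos.json, which are not part of this diff. Below is a minimal sketch of file shapes that would satisfy the code above; only the "prompts" list and the per-ARPAbet "link" field are actually accessed by the app, and every value shown is a placeholder, not the real data shipped with the Space.

# Hypothetical fixtures for prompts.json and videos.json; shapes inferred from
# get_prompts() and video_section(), contents are placeholders only.
import json
import os

data_dir = os.path.join("wav2vecasr", "data")
os.makedirs(data_dir, exist_ok=True)

# get_prompts() returns data["prompts"], so a top-level "prompts" list is enough
prompts = {"prompts": ["Example practice sentence one.", "Example practice sentence two."]}

# video_section() looks up arpabet_to_video_map[phoneme]["link"]
videos = {
    "TH": {"link": "https://example.com/placeholder-articulation-video"},
    "R": {"link": ""},  # an empty link triggers the "we are working on it" message
}

with open(os.path.join(data_dir, "prompts.json"), "w") as f:
    json.dump(prompts, f, indent=2)

with open(os.path.join(data_dir, "videos.json"), "w") as f:
    json.dump(videos, f, indent=2)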
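The prompt generation uses openai.ChatCompletion.create, i.e. the pre-1.0 openai Python package, and reads the key from an OPENAI_KEY environment variable; because get_prompt_from_openai() wraps the call in a bare except, a missing or invalid key silently falls back to L2-Arctic prompts only. A rough standalone sketch of the same request, assuming openai<1.0 is installed and OPENAI_KEY is set (the word list is illustrative):

# Standalone sketch of the prompt-generation request used in app.py
# (openai<1.0 ChatCompletion API); not the app's exact prompt wording.
import os

import openai

openai.api_key = os.getenv("OPENAI_KEY")

words_with_errors = "thought, world"  # e.g. the mispronounced words detected above
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are writing short practice reading prompts for English learners."},
        {"role": "user", "content": f"Write a sentence of less than 10 words including: {words_with_errors}"},
    ],
)
print(response["choices"][0]["message"]["content"])

# To exercise the full app locally (key value is hypothetical):
#   OPENAI_KEY=sk-... streamlit run app.py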