davidmeikle committed on
Commit
614dc5d
·
verified ·
1 Parent(s): 851cec6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -18
app.py CHANGED
@@ -125,7 +125,9 @@ class PhonemeTranscriber:
125
  self.device = self._get_optimal_device()
126
  print(f"Using device: {self.device}")
127
 
128
- self.model_config = self._initialize_model()
 
 
129
  self.target_sample_rate = 16_000
130
  self.enhancer = PhoneticEnhancer()
131
 
@@ -135,18 +137,6 @@ class PhonemeTranscriber:
135
  elif torch.backends.mps.is_available() and platform.system() == 'Darwin':
136
  return "mps"
137
  return "cpu"
138
-
139
- def _initialize_model(self) -> ModelConfig:
140
- model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
141
- processor = Wav2Vec2Processor.from_pretrained(model_name)
142
- model = Wav2Vec2ForCTC.from_pretrained(model_name)
143
-
144
- return ModelConfig(
145
- name=model_name,
146
- processor=processor,
147
- model=model,
148
- description="LV-60 + CommonVoice (26 langs) + eSpeak"
149
- )
150
 
151
  def preprocess_audio(self, audio):
152
  """Preprocess audio data for model input."""
@@ -180,8 +170,13 @@ class PhonemeTranscriber:
180
  audio_data = self.preprocess_audio(audio)
181
  if audio_data is None:
182
  return "Please provide valid audio input"
 
 
 
 
 
183
  selected_enhancements = enhancements.split(',') if enhancements else []
184
- inputs = self.model_config.processor(
185
  audio_data,
186
  sampling_rate=self.target_sample_rate,
187
  return_tensors="pt",
@@ -189,19 +184,24 @@ class PhonemeTranscriber:
189
  ).input_values.to(self.device)
190
 
191
  with torch.no_grad():
192
- logits = self.model_config.model(inputs).logits
193
 
194
  predicted_ids = torch.argmax(logits, dim=-1)
195
- transcription = self.model_config.processor.batch_decode(predicted_ids)[0]
196
 
197
  enhanced = self.enhancer.enhance_transcription(
198
  transcription,
199
  selected_enhancements
200
  )
201
 
 
 
 
 
 
202
  return f"""Raw IPA: {transcription}
203
- Enhanced IPA: {enhanced}
204
- Applied enhancements: {', '.join(selected_enhancements) or 'none'}"""
205
 
206
  except Exception as e:
207
  import traceback
 
125
  self.device = self._get_optimal_device()
126
  print(f"Using device: {self.device}")
127
 
128
+ # Store model name and initialize processor only
129
+ self.model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
130
+ self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
131
  self.target_sample_rate = 16_000
132
  self.enhancer = PhoneticEnhancer()
133
 
 
137
  elif torch.backends.mps.is_available() and platform.system() == 'Darwin':
138
  return "mps"
139
  return "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  def preprocess_audio(self, audio):
142
  """Preprocess audio data for model input."""
 
170
  audio_data = self.preprocess_audio(audio)
171
  if audio_data is None:
172
  return "Please provide valid audio input"
173
+
174
+ # Load model inside GPU context
175
+ model = Wav2Vec2ForCTC.from_pretrained(self.model_name).to(self.device)
176
+ model.eval()
177
+
178
  selected_enhancements = enhancements.split(',') if enhancements else []
179
+ inputs = self.processor(
180
  audio_data,
181
  sampling_rate=self.target_sample_rate,
182
  return_tensors="pt",
 
184
  ).input_values.to(self.device)
185
 
186
  with torch.no_grad():
187
+ logits = model(inputs).logits
188
 
189
  predicted_ids = torch.argmax(logits, dim=-1)
190
+ transcription = self.processor.batch_decode(predicted_ids)[0]
191
 
192
  enhanced = self.enhancer.enhance_transcription(
193
  transcription,
194
  selected_enhancements
195
  )
196
 
197
+ # Clean up to free GPU memory
198
+ del model
199
+ if torch.cuda.is_available():
200
+ torch.cuda.empty_cache()
201
+
202
  return f"""Raw IPA: {transcription}
203
+ Enhanced IPA: {enhanced}
204
+ Applied enhancements: {', '.join(selected_enhancements) or 'none'}"""
205
 
206
  except Exception as e:
207
  import traceback