Yehor committed · Commit 0edad78 · verified · 1 Parent(s): 0919331

Update app.py

Files changed (1): app.py (+7 −12)
app.py CHANGED
@@ -9,10 +9,10 @@ import torchaudio.transforms as T
 
 import gradio as gr
 
-from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
+from transformers import HubertForCTC, Wav2Vec2Processor
 
 # Config
-model_name = "Yehor/w2v-bert-2.0-uk-v2"
+model_name = "Yehor/mHuBERT-147-uk"
 
 min_duration = 0.5
 max_duration = 60
@@ -25,10 +25,8 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
 # Load the model
-asr_model = AutoModelForCTC.from_pretrained(model_name, torch_dtype=torch_dtype).to(
-    device
-)
-processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
+asr_model = HubertForCTC.from_pretrained(model_name, torch_dtype=torch_dtype, device_map=device)
+processor = Wav2Vec2Processor.from_pretrained(model_name)
 
 if use_torch_compile:
     asr_model = torch.compile(asr_model)
@@ -156,13 +154,10 @@ def inference(audio_path, progress=gr.Progress()):
     resampler = T.Resample(sr, 16_000, dtype=audio_input.dtype)
     audio_input = resampler(audio_input)
 
-    audio_input = audio_input.squeeze().numpy()
-
-    features = processor([audio_input], sampling_rate=16_000).input_features
-    features = torch.tensor(features).to(device)
+    audio_input = audio_input.squeeze(0).numpy()
 
-    if torch_dtype == torch.float16:
-        features = features.half()
+    inputs = processor([audio_input], sampling_rate=16_000, padding=True).input_values
+    features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)
 
     with torch.inference_mode():
         logits = asr_model(features).logits
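For context, the new loading path in isolation: a minimal sketch assuming only what the diff shows (the model name and classes come from the commit; the device/dtype lines are copied from the unchanged config above the hunk).

import torch
from transformers import HubertForCTC, Wav2Vec2Processor

model_name = "Yehor/mHuBERT-147-uk"

# Same device/dtype policy as the app: fp16 on GPU, fp32 on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# device_map=device replaces the old .to(device) chain, so from_pretrained
# places the weights on the target device as they are loaded.
asr_model = HubertForCTC.from_pretrained(
    model_name, torch_dtype=torch_dtype, device_map=device
)
processor = Wav2Vec2Processor.from_pretrained(model_name)

Note the processor swap: Wav2Vec2BertProcessor produced log-mel input_features, while Wav2Vec2Processor returns raw waveform input_values, which is what the HuBERT CTC model consumes. That is why the inference hunk changes in step with the imports.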
 
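And a sketch of the updated inference path, continuing from the loading snippet above (processor, asr_model, device, torch_dtype). The file path and the greedy CTC decode at the end are illustrative additions, not part of the diff, which stops at logits; the np.array call also implies app.py imports numpy as np elsewhere.

import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T

audio_path = "sample.wav"  # hypothetical input file

# torchaudio.load returns a (channels, time) tensor plus the sample rate.
audio_input, sr = torchaudio.load(audio_path)

# The model operates on 16 kHz audio; resample anything else.
if sr != 16_000:
    resampler = T.Resample(sr, 16_000, dtype=audio_input.dtype)
    audio_input = resampler(audio_input)

# squeeze(0) drops the channel dimension of mono audio, (1, T) -> (T,),
# without squeezing other size-1 dimensions the way bare squeeze() can.
audio_input = audio_input.squeeze(0).numpy()

# padding=True lets the same call handle batches of unequal length;
# np.array stacks the padded list before the tensor conversion, and
# dtype=torch_dtype replaces the old "if float16 then .half()" branch.
inputs = processor([audio_input], sampling_rate=16_000, padding=True).input_values
features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)

with torch.inference_mode():
    logits = asr_model(features).logits

# Greedy CTC decode (illustrative; not shown in the diff).
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids)[0])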