File size: 1,786 Bytes
57926d1
 
 
a05c869
0448aa2
 
a05c869
57926d1
a05c869
57926d1
 
0448aa2
57926d1
 
3b8d409
57926d1
 
3b8d409
57926d1
 
0448aa2
57926d1
 
 
 
 
0448aa2
57926d1
 
 
0448aa2
3b8d409
 
0448aa2
d32240b
3b8d409
 
0448aa2
 
 
 
3b8d409
0448aa2
57926d1
0448aa2
3b8d409
57926d1
0448aa2
3b8d409
0448aa2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import nltk
import librosa
import torch
import gradio as gr
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Fetch the NLTK "punkt" resource when the module is imported.
# fix_transcription_casing calls nltk.sent_tokenize, which presumably relies on
# this resource -- confirm against the NLTK version in use.
nltk.download("punkt")

# Load the pretrained Wav2Vec2 processor and CTC model once, at import time,
# so that every call to predict_and_decode reuses the same instances.
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

def load_and_fix_data(input_file, target_sample_rate=16000):
  """Load an audio file as a mono waveform sampled at ``target_sample_rate``.

  Args:
    input_file: Path (or other source accepted by ``librosa.load``) of the
      audio to read.
    target_sample_rate: Sample rate, in Hz, of the returned waveform.
      Defaults to 16000, the rate passed to the processor in
      ``predict_and_decode``.

  Returns:
    The waveform as returned by ``librosa.load`` with ``mono=True``
    (a one-dimensional array of samples).
  """
  # Let librosa do the down-mix and the resampling in a single step.
  # The previous version loaded with librosa's defaults (already mono, at
  # librosa's default rate) and then resampled a second time using positional
  # sample-rate arguments, which newer librosa releases reject; its manual
  # channel-summing branch could never run on the already-mono signal.
  speech, _ = librosa.load(input_file, sr=target_sample_rate, mono=True)
  return speech
  
  
def fix_transcription_casing(input_sentence):
  """Capitalize the first character of every sentence in ``input_sentence``.

  The text is split into sentences with ``nltk.sent_tokenize``; each sentence
  gets its leading character capitalized and the sentences are joined back
  together with single spaces.
  """
  capitalized = []
  for sentence in nltk.sent_tokenize(input_sentence):
    first_char = sentence[0]
    # Only the first occurrence (the leading character itself) is replaced.
    capitalized.append(sentence.replace(first_char, first_char.capitalize(), 1))
  return ' '.join(capitalized)
  
# Lazily-built pyctcdecode decoder shared by all requests.
_ctc_decoder = None


def _get_ctc_decoder():
  """Return the shared CTC decoder, building it from the vocabulary on first use."""
  global _ctc_decoder
  if _ctc_decoder is None:
    vocab = processor.tokenizer.get_vocab()
    # Order the labels by token id so that label i lines up with logit
    # column i, instead of relying on the dict's iteration order.
    labels = [token for token, _ in sorted(vocab.items(), key=lambda item: item[1])]
    _ctc_decoder = build_ctcdecoder(labels)
  return _ctc_decoder


def predict_and_decode(input_file):
  """Transcribe an audio file with Wav2Vec2 and a pyctcdecode CTC decoder.

  Args:
    input_file: Path of the audio file to transcribe, or ``None`` when no
      audio was supplied (the UI declares the audio input as optional).

  Returns:
    The lower-cased transcription with the first character of each sentence
    capitalized, or an empty string when ``input_file`` is ``None``.
  """
  # Nothing recorded/uploaded: return an empty transcript instead of failing
  # while trying to load a missing file.
  if input_file is None:
    return ""

  speech = load_and_fix_data(input_file)
  input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values

  # Inference only: skip autograd bookkeeping for the forward pass.
  with torch.no_grad():
    logits = model(input_values).logits[0].cpu().numpy()

  # Reuse one decoder across calls rather than rebuilding it per request.
  pred = _get_ctc_decoder().decode(logits)

  return fix_transcription_casing(pred.lower())
  
# Assemble the Gradio UI: one audio input wired to predict_and_decode,
# one text box for the transcription, then start the app.
audio_input = gr.inputs.Audio(
    source="microphone",
    type="filepath",
    optional=True,
    label="Record/ Drop audio",
)
text_output = gr.outputs.Textbox(label="Output Text")

demo = gr.Interface(
    predict_and_decode,
    inputs=audio_input,
    outputs=text_output,
    title="ASR using Wav2Vec 2.0 & pyctcdecode",
    description="Extending HF ASR models with pyctcdecode decoder",
    layout="horizontal",
    examples=[["test.wav"]],
    theme="huggingface",
)
demo.launch()