import nltk
import librosa
import torch
import gradio as gr
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2Processor, AutoModelForCTC
# "punkt" provides the sentence tokenizer used by fix_transcription_casing below
nltk.download("punkt")
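# Note: the four module-level loads below are never referenced again by name;
# presumably they are here to download and cache both checkpoints once at startup,
# so the per-request from_pretrained calls in return_processor_and_model are fast.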
wav2vec2processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec2model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
hubertprocessor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
hubertmodel = AutoModelForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
def return_processor_and_model(model_name):
    return Wav2Vec2Processor.from_pretrained(model_name), AutoModelForCTC.from_pretrained(model_name)
def load_and_fix_data(input_file):
    speech, sample_rate = librosa.load(input_file)
    # Collapse stereo to mono by summing the two channels
    if len(speech.shape) > 1:
        speech = speech[:, 0] + speech[:, 1]
    # Both models expect 16 kHz input
    if sample_rate != 16000:
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
    return speech
def fix_transcription_casing(input_sentence):
    # Split into sentences and capitalize the first character of each one
    sentences = nltk.sent_tokenize(input_sentence)
    return " ".join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences])
def predict_and_ctc_decode(input_file, model_name):
    processor, model = return_processor_and_model(model_name)
    speech = load_and_fix_data(input_file)
    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    logits = model(input_values).logits.cpu().detach().numpy()[0]
    # Beam-search CTC decoding over the raw logits, without a language model
    vocab_list = list(processor.tokenizer.get_vocab().keys())
    decoder = build_ctcdecoder(vocab_list)
    pred = decoder.decode(logits)
    transcribed_text = fix_transcription_casing(pred.lower())
    return transcribed_text
def predict_and_ctc_lm_decode(input_file, model_name):
    processor, model = return_processor_and_model(model_name)
    speech = load_and_fix_data(input_file)
    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    logits = model(input_values).logits.cpu().detach().numpy()[0]
    # Lowercase the vocabulary and order it by token id so it lines up with the logit columns
    vocab_dict = processor.tokenizer.get_vocab()
    sorted_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
    # Beam-search CTC decoding rescored with a KenLM 4-gram language model
    decoder = build_ctcdecoder(
        list(sorted_dict.keys()),
        "4gram_small.arpa.gz",
    )
    pred = decoder.decode(logits)
    transcribed_text = fix_transcription_casing(pred.lower())
    return transcribed_text
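# The ARPA language model above is referenced by a relative path and is not downloaded
# here; it is assumed to be present alongside this script in the Space repository.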
def predict_and_greedy_decode(input_file, model_name):
    processor, model = return_processor_and_model(model_name)
    speech = load_and_fix_data(input_file)
    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    logits = model(input_values).logits
    # Greedy decoding: take the highest-scoring token at every frame
    predicted_ids = torch.argmax(logits, dim=-1)
    pred = processor.batch_decode(predicted_ids)
    transcribed_text = fix_transcription_casing(pred[0].lower())
    return transcribed_text
def return_all_predictions(input_file, model_name):
    return (
        predict_and_ctc_decode(input_file, model_name),
        predict_and_ctc_lm_decode(input_file, model_name),
        predict_and_greedy_decode(input_file, model_name),
    )
gr.Interface(return_all_predictions,
             inputs=[gr.inputs.Audio(source="microphone", type="filepath", label="Record / drop audio"),
                     gr.inputs.Dropdown(["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"], label="Model Name")],
             outputs=[gr.outputs.Textbox(label="Beam CTC decoding"),
                      gr.outputs.Textbox(label="Beam CTC decoding w/ LM"),
                      gr.outputs.Textbox(label="Greedy decoding")],
             title="ASR using Wav2Vec2 / HuBERT & pyctcdecode",
             description="Compare greedy decoding with beam-search CTC decoding, with and without a language model. Record or drop your audio!",
             layout="horizontal",
             examples=[["test1.wav", "facebook/wav2vec2-base-960h"], ["test2.wav", "facebook/hubert-large-ls960-ft"]],
             theme="huggingface",
             enable_queue=True).launch()
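# A minimal sketch of comparing the three decoders without the Gradio UI, assuming the
# bundled example file "test1.wav" is available in the working directory:
#
#   beam, beam_lm, greedy = return_all_predictions("test1.wav", "facebook/wav2vec2-base-960h")
#   print(beam, beam_lm, greedy, sep="\n")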