|
|
|
import librosa |
|
import torch |
|
from transformers import Wav2Vec2ForCTC, AutoProcessor |
|
from transformers import set_seed |
|
import time |
|
|
|
|
|
def transcribe(fp: str, target_lang: str) -> str:
    """
    Transcribe an audio file with the "facebook/mms-1b-all" CTC model.

    The processor and model are loaded for ``target_lang`` on every call,
    the audio is read with librosa at 16 kHz, and the transcript is obtained
    by greedy decoding (arg-max over the last logits dimension). The elapsed
    time, truncated to whole seconds, is printed to stdout before returning.

    Parameters
    ----------
    fp : str
        The file path to the audio file.
    target_lang : str
        The ISO-3 code of the target language.

    Returns
    -------
    transcript : str
        The transcribed text.
    """
    # Fixed seed so repeated calls behave the same way.
    set_seed(555)

    # perf_counter() is monotonic; time.time() follows the wall clock and can
    # jump (NTP sync, manual changes), which would corrupt the elapsed value.
    start_time = time.perf_counter()

    model_id = "facebook/mms-1b-all"
    # Single source of truth for the rate used both to load the audio and to
    # describe it to the processor, so the two can never disagree.
    sample_rate = 16_000

    processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
    model = Wav2Vec2ForCTC.from_pretrained(
        model_id,
        target_lang=target_lang,
        ignore_mismatched_sizes=True,
    )

    # librosa also returns the sample rate; it is discarded because it was
    # requested explicitly via ``sr``.
    signal, _ = librosa.load(fp, sr=sample_rate)
    inputs = processor(signal, sampling_rate=sample_rate, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding; index 0 selects the first (only) item of the batch.
    ids = torch.argmax(logits, dim=-1)[0]
    transcript = processor.decode(ids)

    print("Time elapsed: ", int(time.perf_counter() - start_time), " seconds")
    return transcript