# whisper_asr/app.py
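"""Streamlit demo for OpenAI's Whisper speech-recognition model.

Lets the user pick a Whisper checkpoint, upload an audio file, and see the
detected language and the transcription (or English translation) in the browser.
"""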
import whisper
import torch
import torchaudio
import streamlit as st
# Display names for the languages Whisper supports; only the keys are shown in the
# sidebar. Decoding uses the language code returned by detect_language() (or "en").
LANGUAGES = {
    "english": "en",
    "chinese": "zh",
    "german": "de",
    "spanish": "es",
    "russian": "ru",
    "korean": "ko",
    "french": "fr",
    "japanese": "ja",
    "portuguese": "pt",
    "turkish": "tr",
    "polish": "pl",
    "catalan": "ca",
    "dutch": "nl",
    "arabic": "ar",
    "swedish": "sv",
    "italian": "it",
    "indonesian": "id",
    "hindi": "hi",
    "finnish": "fi",
    "vietnamese": "vi",
    "hebrew": "he",
    "ukrainian": "uk",
    "greek": "el",
    "malay": "ms",
    "czech": "cs",
    "romanian": "ro",
    "danish": "da",
    "hungarian": "hu",
    "tamil": "ta",
    "norwegian": "no",
    "thai": "th",
    "urdu": "ur",
    "croatian": "hr",
    "bulgarian": "bg",
    "lithuanian": "lt",
    "latin": "la",
    "maori": "mi",
    "malayalam": "ml",
    "welsh": "cy",
    "slovak": "sk",
    "telugu": "te",
    "persian": "fa",
    "latvian": "lv",
    "bengali": "bn",
    "serbian": "sr",
    "azerbaijani": "az",
    "slovenian": "sl",
    "kannada": "kn",
    "estonian": "et",
    "macedonian": "mk",
    "breton": "br",
    "basque": "eu",
    "icelandic": "is",
    "armenian": "hy",
    "nepali": "ne",
    "mongolian": "mn",
    "bosnian": "bs",
    "kazakh": "kk",
    "albanian": "sq",
    "swahili": "sw",
    "galician": "gl",
    "marathi": "mr",
    "punjabi": "pa",
    "sinhala": "si",
    "khmer": "km",
    "shona": "sn",
    "yoruba": "yo",
    "somali": "so",
    "afrikaans": "af",
    "occitan": "oc",
    "georgian": "ka",
    "belarusian": "be",
    "tajik": "tg",
    "sindhi": "sd",
    "gujarati": "gu",
    "amharic": "am",
    "yiddish": "yi",
    "lao": "lo",
    "uzbek": "uz",
    "faroese": "fo",
    "haitian creole": "ht",
    "pashto": "ps",
    "turkmen": "tk",
    "nynorsk": "nn",
    "maltese": "mt",
    "sanskrit": "sa",
    "luxembourgish": "lb",
    "myanmar": "my",
    "tibetan": "bo",
    "tagalog": "tl",
    "malagasy": "mg",
    "assamese": "as",
    "tatar": "tt",
    "hawaiian": "haw",
    "lingala": "ln",
    "hausa": "ha",
    "bashkir": "ba",
    "javanese": "jw",
    "sundanese": "su",
}
def decode(model, mel, options):
    # Run Whisper's decoder on the precomputed mel spectrogram and return the text.
    result = whisper.decode(model, mel, options)
    return result.text
def load_audio(path):
    # Read the uploaded file and resample to the 16 kHz mono audio Whisper expects.
    waveform, sample_rate = torchaudio.load(path)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
    return waveform.mean(dim=0)  # average channels so stereo input becomes mono
def detect_language(model, mel):
    # Ask the model for language probabilities and return the most likely language code.
    _, probs = model.detect_language(mel)
    return max(probs, key=probs.get)
def main():
    st.title("Whisper ASR Demo")
    st.markdown(
        """
        This is a demo of OpenAI's Whisper ASR model, trained on 680,000 hours of multilingual and multitask supervised data collected from the web.
        """
    )
    model_selection = st.sidebar.selectbox("Select model", ["tiny", "base", "small", "medium", "large"])
    en_model_selection = st.sidebar.checkbox("English only model", value=False)
    if en_model_selection:
        if model_selection == "large":
            # Whisper does not ship an English-only "large" checkpoint; keep the multilingual one.
            st.sidebar.warning("No English-only variant of 'large' exists; using the multilingual model.")
            en_model_selection = False
        else:
            model_selection += ".en"
    st.sidebar.write(f"Model: {model_selection + (' (English only)' if en_model_selection else ' (Multilingual)')}")
    if st.sidebar.checkbox("Show supported languages", value=False):
        st.sidebar.info(list(LANGUAGES.keys()))
    st.sidebar.title("Options")
    beam_size = st.sidebar.slider("Beam Size", min_value=1, max_value=10, value=5)
    fp16 = st.sidebar.checkbox("Enable FP16 for faster transcription (may slightly affect accuracy)", value=False)
    if not en_model_selection:
        task = st.sidebar.selectbox("Select task", ["transcribe", "translate (To English)"], index=0)
        # whisper.DecodingOptions expects the bare task name, not the display label.
        if task.startswith("translate"):
            task = "translate"
    else:
        task = st.sidebar.selectbox("Select task", ["transcribe"], index=0)
st.title("Audio")
audio_file = st.file_uploader("Upload Audio", type=["wav", "mp3", "flac"])
if audio_file is not None:
st.audio(audio_file, format='audio/ogg')
with st.spinner("Loading model..."):
model = whisper.load_model(model_selection)
model = model.to("cpu") if not torch.cuda.is_available() else model.to("cuda")
audio = load_audio(audio_file)
with st.spinner("Extracting features..."):
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio).to(model.device)
if not en_model_selection:
with st.spinner("Detecting language..."):
language = detect_language(model, mel)
st.markdown(f"Detected Language: {language}")
else:
language = "en"
configuration = {"beam_size": beam_size, "fp16": fp16, "task": task, "language": language}
with st.spinner("Transcribing..."):
options = whisper.DecodingOptions(**configuration)
text = decode(model, mel, options)
st.markdown(f"**Recognized Text:** {text}")
if __name__ == "__main__":
    main()
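# A minimal sketch of how to run this demo locally (assumes the imported packages
# are installed from PyPI as openai-whisper, torch, torchaudio, and streamlit):
#
#     pip install -U openai-whisper torch torchaudio streamlit
#     streamlit run app.py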