import torch
import pickle
import whisper
import streamlit as st
import torchaudio as ta

from io import BytesIO
from transformers import WhisperProcessor, WhisperForConditionalGeneration
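
# Launch this app with: streamlit run <path_to_this_file>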

# Run on the GPU in half precision when available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

SAMPLING_RATE = 16000  # Whisper models expect 16 kHz audio

print(f"{device} Active!")

# Load the Hugging Face Whisper checkpoint once at startup and move it to the
# selected device in the selected precision.
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small", torch_dtype=torch_dtype).to(device)

# OpenAI's reference Whisper implementation is used only for language
# detection, so load it once here rather than on every call.
whisper_model = whisper.load_model("base", device=device)

st.title("Audio Player with Live Transcription")
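
# Sidebar controls: upload one or more audio files, then press Submit.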
st.sidebar.header("Upload Audio Files")
uploaded_files = st.sidebar.file_uploader("Choose audio files", type=["mp3", "wav"], accept_multiple_files=True)
submit_button = st.sidebar.button("Submit")

def detect_language(waveform):
    """Identify the spoken language of a 16 kHz waveform with Whisper."""
    # Whisper's language ID runs on a log-Mel spectrogram of (at most) 30 s
    # of mono audio, so take the first channel and pad/trim it.
    trimmed_audio = whisper.pad_or_trim(waveform[0])
    mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)

    # detect_language returns per-language probabilities; keep the most
    # likely language code.
    _, probs = whisper_model.detect_language(mel)
    detected = max(probs, key=probs.get)
    print(f"Detected language: {detected}")
    return detected
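# Streamlit reruns this script top-to-bottom on every widget interaction, so a
# plain `if submit_button:` block would lose its results as soon as another
# button is clicked. The processed files are therefore kept in st.session_state.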
if submit_button and uploaded_files:
    st.session_state["files"] = []

    for uploaded_file in uploaded_files:
        # Decode the upload and resample to the 16 kHz rate Whisper expects.
        waveform, sampling_rate = ta.load(BytesIO(uploaded_file.read()))
        if sampling_rate != SAMPLING_RATE:
            waveform = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)

        st.session_state["files"].append((uploaded_file, waveform, detect_language(waveform)))

for i, (uploaded_file, waveform, detected_language) in enumerate(st.session_state.get("files", [])):
    col1, col2 = st.columns([1, 3])

    with col1:
        st.write(f"**File name**: {uploaded_file.name}")
        st.audio(BytesIO(uploaded_file.getvalue()), format=uploaded_file.type)
        st.write(f"**Detected Language**: {detected_language}")

    with col2:
        if st.button(f"Transcribe {uploaded_file.name}", key=f"transcribe_{i}"):
            input_features = processor(waveform[0].numpy(), sampling_rate=SAMPLING_RATE, return_tensors="pt").input_features
            predicted_ids = model.generate(input_features.to(device, torch_dtype))
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            for line in transcription:
                st.write(line)

        if st.button(f"Translate {uploaded_file.name}", key=f"translate_{i}"):
            # languages.pkl maps Whisper language codes to full language names.
            with open('languages.pkl', 'rb') as f:
                lang_dict = pickle.load(f)
            detected_language_name = lang_dict[detected_language]

            # Recompute the features here: locals do not survive Streamlit reruns.
            input_features = processor(waveform[0].numpy(), sampling_rate=SAMPLING_RATE, return_tensors="pt").input_features
            forced_decoder_ids = processor.get_decoder_prompt_ids(language=detected_language_name, task="translate")
            predicted_ids = model.generate(input_features.to(device, torch_dtype), forced_decoder_ids=forced_decoder_ids)
            translation = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            for line in translation:
                st.write(line)