"""Streamlit app: upload audio files, detect their language, transcribe/translate.

Language identification uses the ``openai-whisper`` "base" checkpoint;
transcription and translation use the Hugging Face ``openai/whisper-small``
checkpoint.
"""

import pickle
from io import BytesIO

import streamlit as st
import torch
import torchaudio as ta
import whisper
from transformers import (
    AutoProcessor,
    SeamlessM4TModel,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
# NOTE: torch_dtype is kept for compatibility but is not applied to either
# model below.

SAMPLING_RATE = 16000
task = "transcribe"
print(f"{device} Active!")


@st.cache_resource
def _load_transcription_model():
    """Load the Whisper processor/model once and share them across reruns."""
    whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    # Move the model to the selected device; inputs are moved to match in
    # _run_whisper().
    return whisper_processor, whisper_model.to(device)


@st.cache_resource
def _load_language_id_model():
    """Load the openai-whisper "base" model once instead of once per file."""
    return whisper.load_model("base")


# load Whisper model and processor
processor, model = _load_transcription_model()

# Title of the app
st.title("Audio Player with Live Transcription")

# Sidebar for file uploader and submit button
st.sidebar.header("Upload Audio Files")
uploaded_files = st.sidebar.file_uploader(
    "Choose audio files", type=["mp3", "wav"], accept_multiple_files=True
)
submit_button = st.sidebar.button("Submit")


def detect_language(audio_file, sampling_rate=SAMPLING_RATE):
    """Return the most probable spoken-language code for an audio signal.

    Args:
        audio_file: Audio samples as a tensor or array, either 1-D
            ``(samples,)`` or 2-D ``(channels, samples)``. For 2-D input only
            the first channel is analysed.
        sampling_rate: Sampling rate of ``audio_file``. The audio is resampled
            to ``SAMPLING_RATE`` when it differs.

    Returns:
        The key with the highest probability in the mapping returned by the
        Whisper model's ``detect_language``.
    """
    audio = torch.as_tensor(audio_file, dtype=torch.float32)
    if audio.ndim > 1:
        audio = audio[0]
    if sampling_rate != SAMPLING_RATE:
        audio = ta.functional.resample(
            audio, orig_freq=sampling_rate, new_freq=SAMPLING_RATE
        )

    whisper_model = _load_language_id_model()
    # Pad/trim to the fixed window length the model expects before computing
    # the log-mel spectrogram.
    trimmed_audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)

    # detect the spoken language; for a single (2-D) mel the probabilities
    # come back as one dict, not a list of dicts.
    _, probs = whisper_model.detect_language(mel)
    language = max(probs, key=probs.get)
    print(f"Detected language: {language}")
    return language


def _load_waveform(uploaded_file):
    """Decode an uploaded file and return its waveform at ``SAMPLING_RATE``."""
    # getvalue() returns the full payload regardless of the stream position,
    # so it keeps working on reruns (read() would return b"" the second time).
    waveform, sampling_rate = ta.load(BytesIO(uploaded_file.getvalue()))
    # Resample if necessary
    if sampling_rate != SAMPLING_RATE:
        waveform = ta.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE
        )
    return waveform


def _load_language_names():
    """Load the language-code -> language-name mapping from ``languages.pkl``."""
    # SECURITY: pickle.load can execute arbitrary code. languages.pkl must be
    # a trusted file shipped with the app, never user-supplied data.
    with open('languages.pkl', 'rb') as f:
        return pickle.load(f)


def _run_whisper(waveform, **generate_kwargs):
    """Run Whisper generation on the first channel of ``waveform``.

    Extra keyword arguments are forwarded to ``model.generate``. Returns the
    list of decoded strings.
    """
    input_features = processor(
        waveform[0].numpy(), sampling_rate=SAMPLING_RATE, return_tensors='pt'
    ).input_features.to(device)
    predicted_ids = model.generate(input_features, **generate_kwargs)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)


# A Streamlit button is only True during the rerun triggered by its own click.
# Remember the submission in session state so the per-file buttons below are
# still rendered (and can fire) on the rerun caused by clicking them.
if submit_button:
    st.session_state["files_submitted"] = True

if st.session_state.get("files_submitted", False) and uploaded_files:
    # Decode every file once and keep its own waveform and detected language
    # together, so each row acts on its own audio.
    analysed_files = []
    for uploaded_file in uploaded_files:
        waveform = _load_waveform(uploaded_file)
        detected_language = detect_language(waveform, SAMPLING_RATE)
        analysed_files.append((uploaded_file, waveform, detected_language))

    # Display each uploaded file with its detected language and an audio player
    for i, (uploaded_file, waveform, detected_language) in enumerate(analysed_files):
        # Two columns, one for the player, one for the buttons
        col1, col2 = st.columns([1, 3])

        with col1:
            st.write(f"**File name**: {uploaded_file.name}")
            st.audio(BytesIO(uploaded_file.getvalue()), format=uploaded_file.type)
            st.write(f"**Detected Language**: {detected_language}")

        with col2:
            # Explicit keys keep the widgets unique even when two uploads
            # share the same file name.
            if st.button(f"Transcribe {uploaded_file.name}", key=f"transcribe-{i}"):
                # Transcription process
                for line in _run_whisper(waveform):
                    st.write(line)

            if st.button(f"Translate {uploaded_file.name}", key=f"translate-{i}"):
                # Translation process
                detected_language_name = _load_language_names()[detected_language]
                forced_decoder_ids = processor.get_decoder_prompt_ids(
                    language=detected_language_name, task="translate"
                )
                for line in _run_whisper(waveform, forced_decoder_ids=forced_decoder_ids):
                    st.write(line)