import pickle
from html import escape
from io import BytesIO
from urllib.parse import quote

import streamlit as st
import torch

from audio_processing import detect_language, process_long_audio, load_and_resample_audio
from config import SAMPLING_RATE
from model_utils import load_models

# Streamlit re-executes this script from the top on every user interaction,
# so both calls below run on each rerun.

# Release cached-but-unused CUDA memory before any model work happens.
torch.cuda.empty_cache()

# NOTE(review): presumably load_models() caches its work internally; if it
# does not, the models are reloaded on every rerun -- confirm in model_utils.
load_models()

# Page heading shown above the per-file results.
st.title("Audio Player with Live Transcription and Translation")

# Sidebar controls: a multi-file audio picker plus an explicit "Submit"
# button. The two results are consumed further down the script.
sidebar = st.sidebar
sidebar.header("Upload Audio Files")
uploaded_files = sidebar.file_uploader(
    "Choose audio files",
    type=["mp3", "wav"],
    accept_multiple_files=True,
)
submit_button = sidebar.button("Submit")

# First run of a browser session: seed every per-session container in one go.
# 'audio_files' doubles as the "already initialised" marker, so the other
# keys are only (re)created when it is missing.
if "audio_files" not in st.session_state:
    for state_key, initial_value in (
        ("audio_files", []),          # files accepted by the last submit
        ("transcriptions", {}),       # file index -> transcription text
        ("translations", {}),         # file index -> translation text
        ("detected_languages", []),   # one entry per file, same order
        ("waveforms", []),            # one entry per file, same order
    ):
        st.session_state[state_key] = initial_value

# Handle a click on "Submit": decode every uploaded file, detect its language
# and replace the session's current batch with the new one.
#
# Fixes over the previous version:
#   * Results are built in local lists and only committed to session state
#     once every file has been processed. Previously 'audio_files' was stored
#     first, so an exception part-way through left it longer than
#     'waveforms' / 'detected_languages' and every later rerun failed with an
#     IndexError while rendering.
#   * Cached transcriptions / translations are keyed by list index, so they
#     are cleared here; otherwise text from the previous batch would be shown
#     under the newly uploaded files.
#   * getvalue() returns the complete upload regardless of the stream's
#     current read position, whereas read() returns nothing once the stream
#     has already been consumed.
if submit_button and uploaded_files is not None:
    new_waveforms = []
    new_languages = []

    for uploaded_file in uploaded_files:
        waveform = load_and_resample_audio(BytesIO(uploaded_file.getvalue()))
        new_waveforms.append(waveform)
        new_languages.append(detect_language(waveform))

    # Commit the whole batch at once so the parallel lists stay in step.
    st.session_state.audio_files = uploaded_files
    st.session_state.waveforms = new_waveforms
    st.session_state.detected_languages = new_languages
    st.session_state.transcriptions = {}
    st.session_state.translations = {}

def _download_link(text, filename, label):
    """Return a right-aligned HTML link that downloads *text* as *filename*.

    The text is percent-encoded so that quotes, '#', '%', newlines and
    non-ASCII characters cannot break the data URI or the surrounding HTML
    attribute. The file name is HTML-escaped for the same reason.
    """
    href = "data:text/plain;charset=UTF-8," + quote(str(text))
    return (
        '<div style="text-align: right;">'
        f'<a href="{href}" download="{escape(filename)}">{label}</a>'
        '</div>'
    )


def _show_result(heading, text, widget_key, filename, link_label):
    """Render a heading, a text area holding *text*, and a download link."""
    st.write(heading)
    st.text_area("", text, height=200, key=widget_key)
    st.markdown(_download_link(text, filename, link_label), unsafe_allow_html=True)


# Render one player per submitted file, each with on-demand transcription and
# translation. Results are cached in session state under the file's index so
# they survive the rerun that every button click triggers.
if 'audio_files' in st.session_state and st.session_state.audio_files:
    for i, uploaded_file in enumerate(st.session_state.audio_files):
        file_name = uploaded_file.name

        st.write(f"**File name**: {file_name}")
        st.audio(uploaded_file, format=uploaded_file.type)
        st.write(f"**Detected Language**: {st.session_state.detected_languages[i]}")

        col1, col2 = st.columns(2)

        with col1:
            # Explicit index-based key: two uploads with the same file name
            # would otherwise produce identical buttons and a duplicate
            # widget error.
            if st.button(f"Transcribe {file_name}", key=f"transcribe_button_{i}"):
                with st.spinner("Transcribing..."):
                    st.session_state.transcriptions[i] = process_long_audio(
                        st.session_state.waveforms[i], SAMPLING_RATE
                    )

            transcription = st.session_state.transcriptions.get(i)
            if transcription:
                _show_result(
                    "**Transcription**:",
                    transcription,
                    f"transcription_{i}",
                    f"transcription_{file_name}.txt",
                    "Download Transcription",
                )

        with col2:
            if st.button(f"Translate {file_name}", key=f"translate_button_{i}"):
                with st.spinner("Translating..."):
                    # SECURITY NOTE: pickle.load can execute arbitrary code.
                    # 'languages.pkl' must remain a trusted file shipped with
                    # the app and never come from user-supplied data.
                    with open('languages.pkl', 'rb') as f:
                        lang_dict = pickle.load(f)

                    # NOTE(review): presumably maps the detected language
                    # value to the name process_long_audio expects; an
                    # unknown value raises KeyError here -- confirm.
                    detected_language_name = lang_dict[st.session_state.detected_languages[i]]

                    st.session_state.translations[i] = process_long_audio(
                        st.session_state.waveforms[i],
                        SAMPLING_RATE,
                        task="translate",
                        language=detected_language_name,
                    )

            translation = st.session_state.translations.get(i)
            if translation:
                _show_result(
                    "**Translation**:",
                    translation,
                    f"translation_{i}",
                    f"translation_{file_name}.txt",
                    "Download Translation",
                )