# Origin: app.py by haoheliu, commit c2747d4 ("Update app.py", verified), 3.95 kB.
import os
import tempfile

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from matplotlib.figure import Figure
from scipy.io import wavfile

from audiosr import build_model, super_resolution, save_wave
# Run on an NVIDIA GPU when CUDA is available, otherwise fall back to the CPU.
# (Apple-silicon MPS is not selected here, despite what an earlier comment said.)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Title and Description
st.title("AudioSR: Versatile Audio Super-Resolution")
st.write("""
Upload your low-resolution audio files, and AudioSR will enhance them to high fidelity!
Supports all types of audio (music, speech, sound effects, etc.) with arbitrary sampling rates.
Only the first 10 seconds of the audio will be processed.
""")

# Upload a single WAV file to enhance (None until the user uploads one).
uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=["wav"])

# Model parameters shown in the sidebar. model_name goes to build_model;
# ddim_steps, guidance_scale and random_seed go to super_resolution.
st.sidebar.title("Model Parameters")
model_name = st.sidebar.selectbox("Select Model", ["basic", "speech"], index=0)
ddim_steps = st.sidebar.slider("DDIM Steps", min_value=10, max_value=100, value=50)
guidance_scale = st.sidebar.slider("Guidance Scale", min_value=1.0, max_value=10.0, value=3.5)
random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step=1)

# Fixed (not user-tunable) value forwarded as `latent_t_per_second` to
# super_resolution.
# NOTE(review): presumably the number of latent frames per second of audio in
# AudioSR's model — confirm against the audiosr package before changing it.
latent_t_per_second = 12.8
# Helper function to plot spectrogram
def plot_spectrogram(audio_path, title):
    """Render a dB-scaled mel spectrogram of an audio file as a matplotlib figure.

    The file is loaded at its native sampling rate and converted to a 128-band
    mel power spectrogram up to the Nyquist frequency (sr // 2). Power is
    expressed in dB relative to the spectrogram's peak.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        title: Title drawn above the plot.

    Returns:
        A standalone ``matplotlib.figure.Figure``, suitable for ``st.pyplot``.
    """
    y, sr = librosa.load(audio_path, sr=None)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=sr // 2)
    S_dB = librosa.power_to_db(S, ref=np.max)

    # Build the figure through the object-oriented API rather than pyplot's
    # global state. pyplot keeps every figure open until it is explicitly
    # closed, so plt.figure() here leaked one figure per call across Streamlit
    # reruns. A bare Figure is garbage-collected once it is no longer referenced.
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    img = librosa.display.specshow(
        S_dB,
        sr=sr,
        x_axis='time',
        y_axis='mel',
        fmax=sr // 2,
        cmap='viridis',
        ax=ax,
    )
    fig.colorbar(img, ax=ax, format='%+2.0f dB')
    ax.set_title(title)
    fig.tight_layout()
    return fig
# Process Button
if uploaded_file and st.button("Enhance Audio"):
    st.write("Processing audio...")
    # NOTE(review): the model is rebuilt on every click; caching it (e.g. with
    # st.cache_resource) would avoid reloading it — confirm the installed
    # Streamlit version supports that before adopting it.

    # All intermediate files live in a temp directory that is removed when the
    # `with` block exits, so everything shown to the user is read inside it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        input_path = os.path.join(tmp_dir, "input.wav")
        truncated_path = os.path.join(tmp_dir, "truncated.wav")
        output_path = os.path.join(tmp_dir, "output.wav")

        # Save the uploaded file locally. getvalue() returns the whole upload
        # regardless of where the buffer's read position currently is.
        with open(input_path, "wb") as f:
            f.write(uploaded_file.getvalue())

        # Load at the native sampling rate and keep only the first 10 seconds.
        y, sr = librosa.load(input_path, sr=None)
        max_samples = sr * 10
        y_truncated = y[:max_samples]
        # librosa.output.write_wav was removed in librosa 0.8. scipy's writer is
        # what it used internally; float32 samples are written as a float WAV.
        wavfile.write(truncated_path, sr, y_truncated)

        # Plot truncated spectrogram
        st.write("Truncated Input Audio Spectrogram (First 10 seconds):")
        truncated_spectrogram = plot_spectrogram(truncated_path, title="Truncated Input Audio Spectrogram")
        st.pyplot(truncated_spectrogram)

        # Build the model and run super-resolution on the truncated clip.
        audiosr = build_model(model_name=model_name, device=device)
        waveform = super_resolution(
            audiosr,
            truncated_path,
            seed=random_seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            latent_t_per_second=latent_t_per_second,
        )

        # Save enhanced audio at 48 kHz.
        # NOTE(review): output_path assumes save_wave writes
        # "<savepath>/<name>.wav" — confirm against audiosr.save_wave.
        save_wave(waveform, inputpath=truncated_path, savepath=tmp_dir, name="output", samplerate=48000)

        # Plot output spectrogram
        st.write("Enhanced Audio Spectrogram:")
        output_spectrogram = plot_spectrogram(output_path, title="Enhanced Audio Spectrogram")
        st.pyplot(output_spectrogram)

        # Read both clips into memory while the temp directory still exists.
        # Each handle is closed promptly; the original open(...).read() left
        # one open.
        with open(truncated_path, "rb") as f:
            truncated_bytes = f.read()
        with open(output_path, "rb") as f:
            enhanced_bytes = f.read()

        # Each caption comes before its player, as in the spectrogram sections.
        st.write("Truncated Original Audio (First 10 seconds):")
        st.audio(truncated_bytes, format="audio/wav")
        st.write("Enhanced Audio:")
        st.audio(enhanced_bytes, format="audio/wav")
        st.download_button(
            "Download Enhanced Audio",
            data=enhanced_bytes,
            file_name="enhanced_audio.wav",
            mime="audio/wav",
        )

# Footer
st.write("Built with [Streamlit](https://streamlit.io) and [AudioSR](https://audioldm.github.io/audiosr)")