# Streamlit app: classify an uploaded environmental-sound clip with a trained CNN.
import streamlit as st
import tensorflow as tf
import numpy as np
import librosa
import matplotlib.pyplot as plt
import librosa.display
import tempfile
import os
# Load the trained model once and cache it across Streamlit reruns.
@st.cache_resource
def load_model(model_path="sound_classification_model.h5"):
    """Load and cache the trained Keras sound-classification model.

    Without caching, Streamlit would reload the model from disk on every
    user interaction (each rerun of this script).

    Args:
        model_path: Path to the saved ``.h5`` Keras model file.

    Returns:
        The loaded ``tf.keras`` model.
    """
    return tf.keras.models.load_model(model_path)

model = load_model()
# Model output index -> human-readable class label.
_LABEL_NAMES = (
    'Air Conditioner',
    'Car Horn',
    'Children Playing',
    'Dog Bark',
    'Drilling',
    'Engine Idling',
    'Gun Shot',
    'Jackhammer',
    'Siren',
    'Street Music',
)
CLASS_LABELS = dict(enumerate(_LABEL_NAMES))
# Preprocess audio into a fixed-size, normalized log-mel spectrogram.
def preprocess_audio(file_path, n_mels=128, fixed_time_steps=128):
    """Convert an audio file into a normalized log-mel spectrogram.

    The time axis is zero-padded or truncated to exactly
    ``fixed_time_steps`` frames so every sample matches the CNN's
    expected input shape.

    Args:
        file_path: Path to an audio file readable by librosa.
        n_mels: Number of mel frequency bands.
        fixed_time_steps: Target number of time frames.

    Returns:
        A ``(n_mels, fixed_time_steps, 1)`` array, or ``None`` if the
        file could not be processed.
    """
    try:
        y, sr = librosa.load(file_path, sr=None)
        mel_spectrogram = librosa.feature.melspectrogram(
            y=y, sr=sr, n_mels=n_mels, fmax=sr / 2
        )
        log_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
        # Guard against division by zero: a silent/constant clip yields an
        # all-zero dB spectrogram, and the original unconditional divide
        # produced NaNs that would poison the model input.
        peak = np.max(np.abs(log_spectrogram))
        if peak > 0:
            log_spectrogram = log_spectrogram / peak
        if log_spectrogram.shape[1] < fixed_time_steps:
            padding = fixed_time_steps - log_spectrogram.shape[1]
            log_spectrogram = np.pad(
                log_spectrogram, ((0, 0), (0, padding)), mode='constant'
            )
        else:
            log_spectrogram = log_spectrogram[:, :fixed_time_steps]
        # Add trailing channel dimension for CNN input.
        return np.expand_dims(log_spectrogram, axis=-1)
    except Exception as e:
        # Best-effort contract: report the failure and return None so the
        # UI can show a friendly message instead of crashing.
        print(f"Error processing {file_path}: {e}")
        return None
# Streamlit app UI: page title and instructions.
st.title("Audio Spectrogram Prediction")
st.write("Upload an audio file to generate a spectrogram and predict its class using your trained model.")

# File upload widget — restricted to the audio formats this app handles.
uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3"])
if uploaded_file is not None:
    # Persist the upload to disk: librosa.load needs a real file path.
    # NOTE(review): the ".wav" suffix is applied even to mp3 uploads;
    # librosa decodes by content so this appears harmless — confirm.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
        temp_audio_file.write(uploaded_file.read())
        temp_audio_path = temp_audio_file.name

    try:
        # Preprocess the audio into a spectrogram.
        st.write("Processing audio into a spectrogram...")
        spectrogram = preprocess_audio(temp_audio_path)

        if spectrogram is not None:
            # Display the spectrogram. Use an explicit Figure instead of the
            # global pyplot state: st.pyplot(plt) is deprecated and leaks
            # figures across reruns.
            st.write("Generated Spectrogram:")
            fig = plt.figure(figsize=(10, 4))
            # NOTE(review): sr/fmax here only label the axes and may not match
            # the file's true sample rate (loaded with sr=None) — confirm.
            librosa.display.specshow(
                spectrogram[:, :, 0], sr=22050,
                x_axis='time', y_axis='mel', fmax=8000, cmap='plasma'
            )
            plt.colorbar(format='%+2.0f dB')
            plt.title('Mel-Spectrogram')
            plt.tight_layout()
            st.pyplot(fig)
            plt.close(fig)  # free the figure once rendered

            # Predict using the model.
            st.write("Predicting...")
            batch = np.expand_dims(spectrogram, axis=0)  # add batch dimension
            predictions = model.predict(batch)
            predicted_class_index = np.argmax(predictions, axis=-1)[0]
            predicted_class_label = CLASS_LABELS.get(predicted_class_index, "Unknown")

            # Display the results.
            st.write("Prediction Results:")
            st.write(f"**Predicted Class:** {predicted_class_label} (Index: {predicted_class_index})")
            st.write(f"**Raw Model Output:** {predictions}")
        else:
            st.write("Failed to process the audio file. Please try again with a different file.")
    finally:
        # Always clean up the temp file — the original only removed it on the
        # success path, leaking it whenever prediction/plotting raised.
        os.remove(temp_audio_path)
# st.write("### Developer Team")
# developer_info = [
#     {"name": "Faheyra", "image_url": "https://italeemc.iium.edu.my/pluginfile.php/21200/user/icon/remui/f3?rev=40826", "title": "MLetops Engineer"},
#     {"name": "Adilah", "image_url": "https://italeemc.iium.edu.my/pluginfile.php/21229/user/icon/remui/f3?rev=43498", "title": "Ra-Sis-Chear"},
#     {"name": "Aida", "image_url": "https://italeemc.iium.edu.my/pluginfile.php/21236/user/icon/remui/f3?rev=43918", "title": "Ra-Sis-Chear"},
#     {"name": "Naufal", "image_url": "https://italeemc.iium.edu.my/pluginfile.php/21260/user/icon/remui/f3?rev=400622", "title": "Rizzichear"},
#     {"name": "Fadzwan", "image_url": "https://italeemc.iium.edu.my/pluginfile.php/21094/user/icon/remui/f3?rev=59457", "title": "Nasser"},
# ]
# # Dynamically create columns based on the number of developers
# num_devs = len(developer_info)
# cols = st.columns(num_devs)
# # Display the developer profiles
# for idx, dev in enumerate(developer_info):
#     col = cols[idx]
#     with col:
#         st.markdown(
#             f'<div style="display: flex; flex-direction: column; align-items: center;">'
#             f'<img src="{dev["image_url"]}" width="100" style="border-radius: 50%;">'
#             f'<p>{dev["name"]}<br>{dev["title"]}</p>'
#             f'<p></p>'
#             f'</div>',
#             unsafe_allow_html=True
#         )