import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import torch
import torchaudio
import time
from transformers import WhisperForAudioClassification, AutoFeatureExtractor
# Set page title and favicon
# st.set_page_config(page_title="Audio Visualization", page_icon="🎧")
# Upload audio file
audio_file = st.file_uploader("Upload Audio file for Assessment", type=["wav", "mp3"])
# Load the model and feature extractor
model = WhisperForAudioClassification.from_pretrained("Huma10/Whisper_Stuttered_Speech")
feature_extractor = AutoFeatureExtractor.from_pretrained("Huma10/Whisper_Stuttered_Speech")
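# Note: Streamlit reruns this whole script on every interaction, so the model is
# reloaded each time. A cached loader avoids that; a minimal sketch, assuming
# st.cache_resource is available (Streamlit >= 1.18):
#
# @st.cache_resource
# def load_model():
#     model = WhisperForAudioClassification.from_pretrained("Huma10/Whisper_Stuttered_Speech")
#     extractor = AutoFeatureExtractor.from_pretrained("Huma10/Whisper_Stuttered_Speech")
#     return model, extractor
#
# model, feature_extractor = load_model()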
total_inference_time = 0.0  # Accumulates per-clip inference time
# Check if an audio file is uploaded
if audio_file is not None:
    st.audio(audio_file, format="audio/wav")

    # Load the uploaded audio file; keep its sample rate for the checks below
    input_audio, sample_rate = torchaudio.load(audio_file)
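    # The fine-tuned Whisper checkpoint expects 16 kHz mono audio, while uploads
    # may be stereo or recorded at another rate. This normalization step is an
    # added safeguard, not part of the original app (uses torchaudio.functional.resample):
    if input_audio.size(0) > 1:
        input_audio = input_audio.mean(dim=0, keepdim=True)  # downmix to mono
    if sample_rate != 16000:
        input_audio = torchaudio.functional.resample(input_audio, sample_rate, 16000)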
    # Save the filename
    audio_filename = audio_file.name

    # Segment the audio into 3-second clips; a trailing remainder shorter
    # than 3 seconds is dropped
    target_duration = 3  # seconds
    target_samples = int(target_duration * 16000)  # samples per clip at 16 kHz
    num_clips = input_audio.size(1) // target_samples
    audio_clips = [
        input_audio[:, i * target_samples : (i + 1) * target_samples]
        for i in range(num_clips)
    ]

    predicted_labels_list = []
    # Perform inference for each clip
    for clip in audio_clips:
        inputs = feature_extractor(clip.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features

        # Whisper expects log-mel features with exactly 3000 frames; pad or truncate
        if input_features.shape[-1] < 3000:
            pad_length = 3000 - input_features.shape[-1]
            input_features = torch.nn.functional.pad(input_features, (0, pad_length), mode='constant', value=0)
        elif input_features.shape[-1] > 3000:
            input_features = input_features[:, :, :3000]

        # Measure inference time
        start_time = time.time()
        with torch.no_grad():
            logits = model(input_features).logits
        end_time = time.time()
        total_inference_time += end_time - start_time  # Accumulate inference time

        # Convert logits to human-readable labels
        predicted_class_ids = torch.argmax(logits, dim=-1)
        predicted_labels = [model.config.id2label[class_id.item()] for class_id in predicted_class_ids]
        predicted_labels_list.extend(predicted_labels)
    st.markdown(f"Total inference time: **{total_inference_time:.4f}** seconds")
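    # Optional: tabulate each 3-second segment with its predicted label.
    # This table is an illustrative addition, not part of the original app.
    timeline = pd.DataFrame({
        "start (s)": [i * target_duration for i in range(len(predicted_labels_list))],
        "predicted label": predicted_labels_list,
    })
    st.dataframe(timeline)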
    def calculate_percentages(predicted_labels):
        # Share of each disfluency type among all classified clips
        disfluency_count = pd.Series(predicted_labels).value_counts(normalize=True)
        return disfluency_count * 100  # Convert fractions to percentages
    def plot_disfluency_percentages(percentages):
        fig, ax = plt.subplots()
        percentages.plot(kind='bar', ax=ax, color='#70bdbd')
        ax.set_title('Percentage of Each Disfluency Type')
        ax.set_xlabel('Disfluency Type')
        ax.set_ylabel('Percentage')
        plt.xticks(rotation=45)
        return fig
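    # The plotly import is otherwise unused; an interactive alternative to the
    # matplotlib chart could look like this (hypothetical helper, shown as a
    # sketch; render with st.plotly_chart(plot_disfluency_percentages_px(percentages))):
    def plot_disfluency_percentages_px(percentages):
        return px.bar(
            x=percentages.index,
            y=percentages.values,
            labels={"x": "Disfluency Type", "y": "Percentage"},
            title="Percentage of Each Disfluency Type",
        )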
    # Streamlit application
    def main():
        st.title("Speech Profile")
        st.write("This app analyzes the percentage of different types of disfluencies in stuttered speech.")
        # Calculate percentages
        percentages = calculate_percentages(predicted_labels_list)
        # Plot
        fig = plot_disfluency_percentages(percentages)
        st.pyplot(fig)

    main()
    # Show a success banner for five seconds, then clear it. Rebinding the name
    # to a fresh st.empty() (as before) leaves the banner on screen; writing the
    # message into a placeholder and emptying that placeholder actually clears it.
    success_check = st.empty()
    success_check.success('Assessment Completed Successfully!', icon="✅")
    time.sleep(5)
    success_check.empty()