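"""Streamlit app for profiling disfluencies in stuttered speech.

The app accepts a WAV/MP3 upload, splits it into 3-second clips, classifies
each clip with the Huma10/Whisper_Stuttered_Speech model, and plots the
percentage of each predicted disfluency type.
"""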
import time

import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchaudio
from transformers import WhisperForAudioClassification, AutoFeatureExtractor
# Set page title and favicon
#st.set_page_config(page_title="Audio Visualization", page_icon="🎧")
# Upload audio file
audio_file = st.file_uploader("Upload Audio file for Assessment", type=["wav", "mp3"])
# Load the model and feature extractor once (cached across Streamlit reruns)
@st.cache_resource
def load_model():
    return (WhisperForAudioClassification.from_pretrained("Huma10/Whisper_Stuttered_Speech"),
            AutoFeatureExtractor.from_pretrained("Huma10/Whisper_Stuttered_Speech"))

model, feature_extractor = load_model()
total_inference_time = 0  # Initialize the total inference time
# Check if an audio file is uploaded
if audio_file is not None:
    st.audio(audio_file, format="audio/wav")

    # Load the uploaded audio file
    input_audio, sample_rate = torchaudio.load(audio_file)

    # Downmix to mono and resample to the 16 kHz rate the model expects
    if input_audio.size(0) > 1:
        input_audio = input_audio.mean(dim=0, keepdim=True)
    if sample_rate != 16000:
        input_audio = torchaudio.functional.resample(input_audio, sample_rate, 16000)

    # Save the filename
    audio_filename = audio_file.name

    # Segment the audio into 3-second clips (any trailing remainder is dropped)
    target_duration = 3  # seconds
    target_samples = int(target_duration * 16000)
    num_clips = input_audio.size(1) // target_samples
    audio_clips = [input_audio[:, i * target_samples : (i + 1) * target_samples] for i in range(num_clips)]

    predicted_labels_list = []
    # Perform inference for each 3-second clip
    for clip in audio_clips:
        inputs = feature_extractor(clip.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features

        # Ensure input features have the required length of 3000 frames
        if input_features.shape[-1] < 3000:
            pad_length = 3000 - input_features.shape[-1]
            input_features = torch.nn.functional.pad(input_features, (0, pad_length), mode='constant', value=0)
        elif input_features.shape[-1] > 3000:
            input_features = input_features[:, :, :3000]
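        # Note: 3000 frames corresponds to Whisper's fixed 30-second input
        # window (log-mel features at ~100 frames per second), so shorter
        # clips are zero-padded and longer ones truncated before inference.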
        # Time the forward pass
        start_time = time.time()
        with torch.no_grad():
            logits = model(input_features).logits
        end_time = time.time()
        inference_time = end_time - start_time
        total_inference_time += inference_time  # Accumulate inference time

        # Convert logits to predicted disfluency labels
        predicted_class_ids = torch.argmax(logits, dim=-1)
        predicted_labels = [model.config.id2label[class_id.item()] for class_id in predicted_class_ids]
        predicted_labels_list.extend(predicted_labels)

    st.markdown(f"Total inference time: **{total_inference_time:.4f}** seconds")
    def calculate_percentages(predicted_labels):
        # Share of clips assigned to each disfluency type
        disfluency_count = pd.Series(predicted_labels).value_counts(normalize=True)
        return disfluency_count * 100  # Convert fractions to percentages

    def plot_disfluency_percentages(percentages):
        fig, ax = plt.subplots()
        percentages.plot(kind='bar', ax=ax, color='#70bdbd')
        ax.set_title('Percentage of Each Disfluency Type')
        ax.set_xlabel('Disfluency Type')
        ax.set_ylabel('Percentage')
        plt.xticks(rotation=45)
        return fig
    # Streamlit application body
    def main():
        st.title("Speech Profile")
        st.write("This app analyzes the percentage of different types of disfluencies in stuttered speech.")
        # Calculate percentages and plot them
        percentages = calculate_percentages(predicted_labels_list)
        fig = plot_disfluency_percentages(percentages)
        st.pyplot(fig)

    main()
    success_check = st.success('Assessment Completed Successfully!', icon="✅")
    time.sleep(5)
    success_check.empty()  # Clear the success message after 5 seconds
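
# To run locally (assuming Streamlit is installed): streamlit run app.py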