File size: 3,725 Bytes
5dbbb0c
 
 
 
 
 
 
 
 
 
 
 
ae11109
5dbbb0c
 
 
 
 
 
 
 
1b86db7
5dbbb0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45eb876
1b86db7
 
45eb876
 
 
1b86db7
5dbbb0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b86db7
5dbbb0c
1b86db7
 
 
5dbbb0c
 
 
 
 
 
 
 
 
 
1b86db7
5dbbb0c
 
 
 
 
 
 
 
 
 
 
 
 
1b86db7
5dbbb0c
1b86db7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
import pandas as pd
import numpy as np
import librosa.display
import matplotlib.pyplot as plt
import plotly.express as px
import torch
import torchaudio
import time
from transformers import WhisperForAudioClassification, AutoFeatureExtractor

# Set page title and favicon
#st.set_page_config(page_title="Audio Visualization", page_icon="🎧")

# Upload audio file
audio_file = st.file_uploader("Upload Audio file for Assessment", type=["wav", "mp3"])

# Hugging Face Hub repository that provides both the classifier and its feature extractor.
MODEL_ID = "Huma10/Whisper_Stuttered_Speech"


@st.cache_resource
def _load_model_and_extractor(model_id):
    """Load the audio classifier and its feature extractor once per server process.

    Streamlit re-executes this script from top to bottom on every user
    interaction; caching the loaded objects avoids re-reading the checkpoint
    on each rerun.

    Args:
        model_id: Hub repository id (or local path) of the checkpoint.

    Returns:
        A ``(model, feature_extractor)`` tuple.
    """
    classifier = WhisperForAudioClassification.from_pretrained(model_id)
    extractor = AutoFeatureExtractor.from_pretrained(model_id)
    return classifier, extractor


# Load the model and processor (served from the cache after the first run)
model, feature_extractor = _load_model_and_extractor(MODEL_ID)
total_inference_time = 0  # Initialize the total inference time

# Check if an audio file is uploaded
if audio_file is not None:
    # Pass the upload's own MIME type so MP3 files are not announced as WAV.
    st.audio(audio_file, format=audio_file.type or "audio/wav")
    # Load the uploaded audio together with its native sample rate
    input_audio, source_sample_rate = torchaudio.load(audio_file)
    # Save the filename
    audio_filename = audio_file.name

    # All clip-length arithmetic below is expressed at this rate.
    target_sample_rate = 16000

    # Downmix to a single channel: squeeze() on a multi-channel clip would hand
    # the feature extractor a 2-D array instead of one waveform.
    if input_audio.size(0) > 1:
        input_audio = input_audio.mean(dim=0, keepdim=True)

    # Resample so that `target_samples` really corresponds to `target_duration`
    # seconds, whatever rate the file was recorded at.
    if source_sample_rate != target_sample_rate:
        input_audio = torchaudio.functional.resample(
            input_audio, source_sample_rate, target_sample_rate
        )

    # Segment the audio into 3-second clips (a trailing partial clip is dropped)
    target_duration = 3  # 3 seconds
    target_samples = int(target_duration * target_sample_rate)
    num_clips = input_audio.size(1) // target_samples
    if num_clips == 0:
        # Nothing to classify; stop this run instead of failing later on an
        # empty prediction list.
        st.warning(
            f"The uploaded audio is shorter than {target_duration} seconds, "
            "so there is nothing to analyse. Please upload a longer recording."
        )
        st.stop()
    audio_clips = [
        input_audio[:, i * target_samples : (i + 1) * target_samples]
        for i in range(num_clips)
    ]

    predicted_labels_list = []
    required_frames = 3000  # feature length the model is fed

    # Perform inference for each clip
    for clip in audio_clips:
        # Stating the sampling rate lets the extractor check it against its own
        # configuration instead of silently assuming one.
        inputs = feature_extractor(
            clip.squeeze(0).numpy(),
            sampling_rate=target_sample_rate,
            return_tensors="pt",
        )
        input_features = inputs.input_features

        # Ensure input features have the required length of 3000
        frame_count = input_features.shape[-1]
        if frame_count < required_frames:
            pad_length = required_frames - frame_count
            input_features = torch.nn.functional.pad(
                input_features, (0, pad_length), mode='constant', value=0
            )
        elif frame_count > required_frames:
            input_features = input_features[:, :, :required_frames]

        # Measure inference time with a monotonic, high-resolution clock
        start_time = time.perf_counter()
        # Perform inference
        with torch.no_grad():
            logits = model(input_features).logits
        inference_time = time.perf_counter() - start_time
        total_inference_time += inference_time  # Accumulate inference time

        # Convert logits to predictions
        predicted_class_ids = torch.argmax(logits, dim=-1)
        predicted_labels = [model.config.id2label[class_id.item()] for class_id in predicted_class_ids]
        predicted_labels_list.extend(predicted_labels)

    st.markdown(f"Total inference time: **{total_inference_time:.4f}** seconds")
    
    def calculate_percentages(predicted_labels):
        """Return how often each label occurs, expressed as a percentage.

        Args:
            predicted_labels: Sequence of predicted label values, one per clip.

        Returns:
            A pandas Series indexed by label whose values are the relative
            frequencies of the labels multiplied by 100, in the order
            produced by ``value_counts``.
        """
        label_series = pd.Series(predicted_labels)
        # normalize=True yields fractions of the total rather than raw counts.
        label_fractions = label_series.value_counts(normalize=True)
        return label_fractions.mul(100)

    def plot_disfluency_percentages(percentages):
        """Draw a bar chart of the per-label percentages.

        Args:
            percentages: pandas Series whose index holds the labels and whose
                values hold the percentages to plot.

        Returns:
            The matplotlib figure containing the chart.
        """
        figure, axes = plt.subplots()
        percentages.plot.bar(ax=axes, color='#70bdbd')
        axes.set(
            title='Percentage of Each Disfluency Type',
            xlabel='Disfluency Type',
            ylabel='Percentage',
        )
        # Tilt the category labels on the x axis.
        plt.setp(axes.get_xticklabels(), rotation=45)
        return figure

    # Streamlit application
    def main():
        """Render the speech-profile report for the current predictions.

        Shows the page title and description, converts the collected
        per-clip labels into percentages and displays them as a bar chart.
        When there are no predictions, a warning is shown instead of a chart.
        """
        st.title("Speech Profile")
        st.write("This app analyzes the percentage of different types of disfluencies in stuttered speech.")

        # Calculate percentages
        percentages = calculate_percentages(predicted_labels_list)
        if percentages.empty:
            # An empty Series cannot be drawn as a bar chart.
            st.warning("No disfluency predictions are available to plot.")
            return

        # Plot
        fig = plot_disfluency_percentages(percentages)
        st.pyplot(fig)
        # pyplot keeps every figure it creates until it is closed; release it
        # so figures do not accumulate across script reruns.
        plt.close(fig)

    main()

    # Show the confirmation inside a placeholder so it can be removed again:
    # creating a fresh st.empty() afterwards would not clear a message that
    # has already been rendered.
    success_check = st.empty()
    success_check.success('Assessment Completed Successfully!', icon="✅")
    time.sleep(5)  # keep the message visible for a few seconds
    success_check.empty()