Huma10 committed
Commit 5dbbb0c · verified
Parent(s): 1c53dc3

Update app.py

Files changed (1)
app.py +93 -96
app.py CHANGED
@@ -1,97 +1,94 @@
- import streamlit as st
- import pandas as pd
- import numpy as np
- import librosa.display
- import matplotlib.pyplot as plt
- import plotly.express as px
- from streamlit_extras.colored_header import colored_header
- import torch
- import torchaudio
- import time
- from transformers import WhisperForAudioClassification, AutoFeatureExtractor
- from streamlit_option_menu import option_menu
- import matplotlib.colors as mcolors
-
-
- # Set page title and favicon
- st.set_page_config(page_title="Audio Visualization", page_icon="🎧")
-
- # Upload audio file
- audio_file = st.file_uploader("Upload Audio file for Assessment", type=["wav", "mp3"])
-
- # Load the model and processor
- model = WhisperForAudioClassification.from_pretrained("Huma10/Whisper_Stuttered_Speech")
- feature_extractor = AutoFeatureExtractor.from_pretrained("Huma10/Whisper_Stuttered_Speech")
- total_inference_time = 0  # Initialize the total inference time
- # Check if an audio file is uploaded
- if audio_file is not None:
-     st.audio(audio_file, format="audio/wav")
-     # Load and preprocess the uploaded audio file
-     input_audio, _ = torchaudio.load(audio_file)
-     # Save the filename
-     audio_filename = audio_file.name
-     # Segment the audio into 3-second clips
-     target_duration = 3  # 3 seconds
-     target_samples = int(target_duration * 16000)
-     num_clips = input_audio.size(1) // target_samples
-     audio_clips = [input_audio[:, i * target_samples : (i + 1) * target_samples] for i in range(num_clips)]
-
-     predicted_labels_list = []
-
-     # Perform inference for each clip
-     for clip in audio_clips:
-         inputs = feature_extractor(clip.squeeze().numpy(), return_tensors="pt")
-         input_features = inputs.input_features
-
-         # Measure inference time
-         start_time = time.time()
-         # Perform inference
-         with torch.no_grad():
-             logits = model(input_features).logits
-
-         end_time = time.time()
-         inference_time = end_time - start_time
-         total_inference_time += inference_time  # Accumulate inference time
-
-         # Convert logits to predictions
-         predicted_class_ids = torch.argmax(logits, dim=-1)
-         predicted_labels = [model.config.id2label[class_id.item()] for class_id in predicted_class_ids]
-         predicted_labels_list.extend(predicted_labels)
-
-     st.markdown(f"Total inference time: **{total_inference_time:.4f}** seconds")
-     def calculate_percentages(predicted_labels):
-         # Count each type of disfluency
-         disfluency_count = pd.Series(predicted_labels).value_counts(normalize=True)
-         return disfluency_count * 100  # Convert fractions to percentages
-
-     def plot_disfluency_percentages(percentages):
-         fig, ax = plt.subplots()
-         percentages.plot(kind='bar', ax=ax, color='#70bdbd')
-         ax.set_title('Percentage of Each Disfluency Type')
-         ax.set_xlabel('Disfluency Type')
-         ax.set_ylabel('Percentage')
-         plt.xticks(rotation=45)
-         return fig
-
-     # Streamlit application
-     def main():
-         st.title("Speech Profile")
-         st.write("This app analyzes the percentage of different types of disfluencies in stuttered speech.")
-
-         # Calculate percentages
-         percentages = calculate_percentages(predicted_labels_list)
-
-         # Plot
-         fig = plot_disfluency_percentages(percentages)
-         st.pyplot(fig)
-
-     main()
-
-     success_check = st.success(' Assessment Completed Successfully!', icon="✅")
-     time.sleep(5)
-     success_check = st.empty()
 
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ import librosa.display
+ import matplotlib.pyplot as plt
+ import plotly.express as px
+ import torch
+ import torchaudio
+ import time
+ from transformers import WhisperForAudioClassification, AutoFeatureExtractor
+
+
+ # Set page title and favicon
+ st.set_page_config(page_title="Audio Visualization", page_icon="🎧")
+
+ # Upload audio file
+ audio_file = st.file_uploader("Upload Audio file for Assessment", type=["wav", "mp3"])
+
+ # Load the model and processor
+ model = WhisperForAudioClassification.from_pretrained("Huma10/Whisper_Stuttered_Speech")
+ feature_extractor = AutoFeatureExtractor.from_pretrained("Huma10/Whisper_Stuttered_Speech")
+ total_inference_time = 0  # Initialize the total inference time
+ # Check if an audio file is uploaded
+ if audio_file is not None:
+     st.audio(audio_file, format="audio/wav")
+     # Load and preprocess the uploaded audio file
+     input_audio, _ = torchaudio.load(audio_file)
+     # Save the filename
+     audio_filename = audio_file.name
+     # Segment the audio into 3-second clips
+     target_duration = 3  # 3 seconds
+     target_samples = int(target_duration * 16000)
+     num_clips = input_audio.size(1) // target_samples
+     audio_clips = [input_audio[:, i * target_samples : (i + 1) * target_samples] for i in range(num_clips)]
+
+     predicted_labels_list = []
+
+     # Perform inference for each clip
+     for clip in audio_clips:
+         inputs = feature_extractor(clip.squeeze().numpy(), return_tensors="pt")
+         input_features = inputs.input_features
+
+         # Measure inference time
+         start_time = time.time()
+         # Perform inference
+         with torch.no_grad():
+             logits = model(input_features).logits
+
+         end_time = time.time()
+         inference_time = end_time - start_time
+         total_inference_time += inference_time  # Accumulate inference time
+
+         # Convert logits to predictions
+         predicted_class_ids = torch.argmax(logits, dim=-1)
+         predicted_labels = [model.config.id2label[class_id.item()] for class_id in predicted_class_ids]
+         predicted_labels_list.extend(predicted_labels)
+
+     st.markdown(f"Total inference time: **{total_inference_time:.4f}** seconds")
+     def calculate_percentages(predicted_labels):
+         # Count each type of disfluency
+         disfluency_count = pd.Series(predicted_labels).value_counts(normalize=True)
+         return disfluency_count * 100  # Convert fractions to percentages
+
+     def plot_disfluency_percentages(percentages):
+         fig, ax = plt.subplots()
+         percentages.plot(kind='bar', ax=ax, color='#70bdbd')
+         ax.set_title('Percentage of Each Disfluency Type')
+         ax.set_xlabel('Disfluency Type')
+         ax.set_ylabel('Percentage')
+         plt.xticks(rotation=45)
+         return fig
+
+     # Streamlit application
+     def main():
+         st.title("Speech Profile")
+         st.write("This app analyzes the percentage of different types of disfluencies in stuttered speech.")
+
+         # Calculate percentages
+         percentages = calculate_percentages(predicted_labels_list)
+
+         # Plot
+         fig = plot_disfluency_percentages(percentages)
+         st.pyplot(fig)
+
+     main()
+
+     success_check = st.success(' Assessment Completed Successfully!', icon="✅")
+     time.sleep(5)
+     success_check = st.empty()
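Two notes for anyone adapting the updated script. First, the segmentation step assumes 16 kHz audio (`target_samples = int(target_duration * 16000)`), but `torchaudio.load` returns whatever sample rate the uploaded file carries; an mp3 is commonly 44.1 kHz, which would silently shrink each "3-second" window to roughly one second of real audio and feed mis-scaled features to the Whisper model. Below is a minimal sketch of a loader that resamples first. The 16000 constant and the 3-second window come from the diff above; the function name and the mono downmix are illustrative assumptions, not part of the committed app.

import torchaudio

TARGET_SR = 16000   # Whisper feature extractors expect 16 kHz input
CLIP_SECONDS = 3    # same window as the app above

def load_clips_at_16k(audio_file):
    # Load the upload, resample to 16 kHz if needed, downmix to mono,
    # and cut into non-overlapping 3-second clips.
    waveform, sr = torchaudio.load(audio_file)  # shape: (channels, samples)
    if sr != TARGET_SR:
        waveform = torchaudio.functional.resample(waveform, sr, TARGET_SR)
    mono = waveform.mean(dim=0)
    clip_len = CLIP_SECONDS * TARGET_SR
    num_clips = mono.size(0) // clip_len
    return [mono[i * clip_len:(i + 1) * clip_len] for i in range(num_clips)]

Passing the rate explicitly when extracting features, e.g. feature_extractor(clip.numpy(), sampling_rate=16000, return_tensors="pt"), should also silence the Transformers warning about an unspecified sampling rate.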
 
 
 
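Second, Streamlit re-executes the whole script on every widget interaction, so the two `from_pretrained` calls above reload the checkpoint on each rerun. A hedged sketch of one way to avoid that with `st.cache_resource` (available in Streamlit 1.18 and later); the repo id is taken from the diff, the helper name is made up for the example:

import streamlit as st
from transformers import WhisperForAudioClassification, AutoFeatureExtractor

@st.cache_resource  # load once per server process, reuse across reruns
def load_model(repo_id="Huma10/Whisper_Stuttered_Speech"):
    model = WhisperForAudioClassification.from_pretrained(repo_id)
    feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    model.eval()  # inference only
    return model, feature_extractor

model, feature_extractor = load_model()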