Huma10 commited on
Commit
1b86db7
Β·
verified Β·
1 Parent(s): 5dbbb0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -9,7 +9,6 @@ import torchaudio
9
  import time
10
  from transformers import WhisperForAudioClassification, AutoFeatureExtractor
11
 
12
-
13
  # Set page title and favicon
14
  st.set_page_config(page_title="Audio Visualization", page_icon="🎧")
15
 
@@ -20,6 +19,7 @@ audio_file = st.file_uploader("Upload Audio file for Assessment", type=["wav", "
20
  model = WhisperForAudioClassification.from_pretrained("Huma10/Whisper_Stuttered_Speech")
21
  feature_extractor = AutoFeatureExtractor.from_pretrained("Huma10/Whisper_Stuttered_Speech")
22
  total_inference_time = 0 # Initialize the total inference time
 
23
  # Check if an audio file is uploaded
24
  if audio_file is not None:
25
  st.audio(audio_file, format="audio/wav")
@@ -39,8 +39,12 @@ if audio_file is not None:
39
  for clip in audio_clips:
40
  inputs = feature_extractor(clip.squeeze().numpy(), return_tensors="pt")
41
  input_features = inputs.input_features
42
-
43
 
 
 
 
 
 
44
  # Measure inference time
45
  start_time = time.time()
46
  # Perform inference
@@ -57,10 +61,11 @@ if audio_file is not None:
57
  predicted_labels_list.extend(predicted_labels)
58
 
59
  st.markdown(f"Total inference time: **{total_inference_time:.4f}** seconds")
 
60
  def calculate_percentages(predicted_labels):
61
- # Count each type of disfluency
62
- disfluency_count = pd.Series(predicted_labels).value_counts(normalize=True)
63
- return disfluency_count * 100 # Convert fractions to percentages
64
 
65
  def plot_disfluency_percentages(percentages):
66
  fig, ax = plt.subplots()
@@ -71,7 +76,7 @@ if audio_file is not None:
71
  plt.xticks(rotation=45)
72
  return fig
73
 
74
- # Streamlit application
75
  def main():
76
  st.title("Speech Profile")
77
  st.write("This app analyzes the percentage of different types of disfluencies in stuttered speech.")
@@ -83,12 +88,8 @@ if audio_file is not None:
83
  fig = plot_disfluency_percentages(percentages)
84
  st.pyplot(fig)
85
 
86
-
87
  main()
88
 
89
- success_check=st.success(' Assessment Completed Successfully!', icon="βœ…")
90
  time.sleep(5)
91
- success_check=st.empty()
92
-
93
-
94
-
 
9
  import time
10
  from transformers import WhisperForAudioClassification, AutoFeatureExtractor
11
 
 
12
  # Set page title and favicon
13
  st.set_page_config(page_title="Audio Visualization", page_icon="🎧")
14
 
 
19
  model = WhisperForAudioClassification.from_pretrained("Huma10/Whisper_Stuttered_Speech")
20
  feature_extractor = AutoFeatureExtractor.from_pretrained("Huma10/Whisper_Stuttered_Speech")
21
  total_inference_time = 0 # Initialize the total inference time
22
+
23
  # Check if an audio file is uploaded
24
  if audio_file is not None:
25
  st.audio(audio_file, format="audio/wav")
 
39
  for clip in audio_clips:
40
  inputs = feature_extractor(clip.squeeze().numpy(), return_tensors="pt")
41
  input_features = inputs.input_features
 
42
 
43
+ # Pad input features to length 3000
44
+ if input_features.shape[-1] < 3000:
45
+ pad_length = 3000 - input_features.shape[-1]
46
+ input_features = torch.nn.functional.pad(input_features, (0, pad_length))
47
+
48
  # Measure inference time
49
  start_time = time.time()
50
  # Perform inference
 
61
  predicted_labels_list.extend(predicted_labels)
62
 
63
  st.markdown(f"Total inference time: **{total_inference_time:.4f}** seconds")
64
+
65
  def calculate_percentages(predicted_labels):
66
+ # Count each type of disfluency
67
+ disfluency_count = pd.Series(predicted_labels).value_counts(normalize=True)
68
+ return disfluency_count * 100 # Convert fractions to percentages
69
 
70
  def plot_disfluency_percentages(percentages):
71
  fig, ax = plt.subplots()
 
76
  plt.xticks(rotation=45)
77
  return fig
78
 
79
+ # Streamlit application
80
  def main():
81
  st.title("Speech Profile")
82
  st.write("This app analyzes the percentage of different types of disfluencies in stuttered speech.")
 
88
  fig = plot_disfluency_percentages(percentages)
89
  st.pyplot(fig)
90
 
 
91
  main()
92
 
93
+ success_check = st.success('Assessment Completed Successfully!', icon="βœ…")
94
  time.sleep(5)
95
+ success_check = st.empty()