Mohssinibra committed on
Commit 54472e1 · verified · 1 Parent(s): 557b689

14-02-24 10 25

Files changed (1)
  1. app.py +38 -71
app.py CHANGED
@@ -11,17 +11,18 @@ from pydub.utils import mediainfo
 from pydub.silence import detect_nonsilent # Correct import
 import pandas as pd
 
-
+# Load Hugging Face token
 hf_token = os.getenv('diarizationToken')
 
 print("Initializing Speech-to-Text Model...")
 stt_pipeline = pipeline("automatic-speech-recognition", model="boumehdi/wav2vec2-large-xlsr-moroccan-darija")
 print("Model Loaded Successfully.")
 
-# Initialize WhisperX with diarization (not transcription)
+# Initialize WhisperX with diarization
 device = "cuda" if torch.cuda.is_available() else "cpu"
+whisper_model = whisperx.load_model("large-v2", device)
 diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
-print("WhisperX Model Loaded Successfully for Diarization.")
+print("WhisperX Model Loaded Successfully.")
 
 def remove_phone_tonalities(audio, sr):
     nyquist = 0.5 * sr
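Note on the hunk above: whisperx.DiarizationPipeline wraps pyannote speaker diarization, whose weights are gated on the Hub, so the diarizationToken secret must resolve to a valid Hugging Face token. A minimal fail-fast sketch; the check and its error message are illustrative and not part of the commit:

import os

hf_token = os.getenv('diarizationToken')
if not hf_token:
    # Without a token the gated diarization weights cannot be fetched from the Hub.
    raise RuntimeError("Environment variable 'diarizationToken' is not set.")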
@@ -31,18 +32,15 @@ def remove_phone_tonalities(audio, sr):
     filtered_audio = signal.filtfilt(b, a, audio)
     return filtered_audio
 
-
-
 def convert_audio_to_wav(audio_path):
-    # Check the audio file format before conversion
+    """ Convert any supported audio format to WAV. """
     audio_info = mediainfo(audio_path)
     print(f"Audio file info: {audio_info}")
-
-    if audio_info['format_name'] not in ['wav', 'mp3', 'flac', 'ogg']: # Add other valid formats if necessary
+
+    if audio_info['format_name'] not in ['wav', 'mp3', 'flac', 'ogg']:
         raise ValueError(f"Unsupported audio format: {audio_info['format_name']}")
-
+
     try:
-        # Convert any audio format to WAV using pydub
         sound = AudioSegment.from_file(audio_path)
         wav_path = "converted_audio.wav"
         sound.export(wav_path, format="wav")
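The unchanged remove_phone_tonalities context above keeps only the 300–3400 Hz telephony band with a first-order Butterworth band-pass applied zero-phase via filtfilt. A self-contained sketch of the same design on a synthetic signal; the 16 kHz sample rate and the test tones are assumptions for illustration only:

import numpy as np
from scipy import signal

sr = 16000                      # assumed sample rate
t = np.arange(sr) / sr          # one second of samples
audio = np.sin(2 * np.pi * 1000 * t) + np.sin(2 * np.pi * 50 * t)  # in-band tone + low-frequency hum

nyquist = 0.5 * sr
b, a = signal.butter(1, [300 / nyquist, 3400 / nyquist], btype='band')
filtered = signal.filtfilt(b, a, audio)  # zero-phase filtering, same call pattern as app.py

# The 50 Hz hum is strongly attenuated while the 1 kHz tone is largely preserved.
print(float(np.std(audio)), float(np.std(filtered)))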
@@ -51,88 +49,57 @@ def convert_audio_to_wav(audio_path):
         print(f"Error converting audio: {e}")
         raise
 
-
-
-
-hf_token = os.getenv('diarizationToken')
-
-print("Initializing Speech-to-Text Model...")
-stt_pipeline = pipeline("automatic-speech-recognition", model="boumehdi/wav2vec2-large-xlsr-moroccan-darija")
-print("Model Loaded Successfully.")
-
-# Initialize WhisperX with diarization
-device = "cuda" if torch.cuda.is_available() else "cpu"
-whisper_model = whisperx.load_model("large-v2", device)
-diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
-print("WhisperX Model Loaded Successfully.")
-
-def remove_phone_tonalities(audio, sr):
-    nyquist = 0.5 * sr
-    low_cut = 300 / nyquist
-    high_cut = 3400 / nyquist
-    b, a = signal.butter(1, [low_cut, high_cut], btype='band')
-    filtered_audio = signal.filtfilt(b, a, audio)
-    return filtered_audio
-
-
-
 def process_audio(audio_path):
+    """ Process the audio: remove noise, split, diarize, and transcribe. """
     print(f"Received audio file: {audio_path}")
-
+
     try:
-        # Load the audio file using librosa
+        # Load the audio file
         audio, sr = librosa.load(audio_path, sr=None, duration=30)
         print(f"Audio loaded: {len(audio)} samples at {sr} Hz")
 
-        # Remove phone tonalities (if any)
+        # Remove phone tonalities
        audio = remove_phone_tonalities(audio, sr)
        print("Phone tonalities removed")
-
+
        # Convert to AudioSegment for silence detection
        sound = AudioSegment.from_wav(audio_path)
-
-        # Silence detection: split based on silence
-        min_silence_len = 1000 # minimum silence length in ms
-        silence_thresh = sound.dBFS - 14 # threshold for silence (adjust as needed)
-
-        # Correct usage of detect_nonsilent from pydub.silence
-        nonsilent_chunks = detect_nonsilent(
-            sound,
-            min_silence_len=min_silence_len,
-            silence_thresh=silence_thresh
-        )
-
-        non_silent_chunks = [
-            sound[start:end] for start, end in nonsilent_chunks
-        ]
-
-        # Apply diarization (WhisperX)
+
+        # Silence detection
+        min_silence_len = 1000 # Minimum silence length in ms
+        silence_thresh = sound.dBFS - 14 # Threshold for silence detection
+
+        nonsilent_chunks = detect_nonsilent(sound, min_silence_len=min_silence_len, silence_thresh=silence_thresh)
+
+        # Apply diarization
        diarization = diarize_model(audio_path)
-
-        # Check if diarization is a DataFrame and process accordingly
+
        if isinstance(diarization, pd.DataFrame):
-            print("Diarization is a DataFrame")
-            diarization = diarization.to_dict(orient="records") # Convert DataFrame to a list of dicts
-
+            diarization = diarization.to_dict(orient="records")
+
        transcriptions = []
-        for chunk in non_silent_chunks:
+
+        for start, end in nonsilent_chunks:
+            chunk = sound[start:end]
            chunk.export("chunk.wav", format="wav")
+
+            # Track start time manually
+            chunk_start_time = start / 1000.0 # Convert ms to seconds
+
            chunk_audio, chunk_sr = librosa.load("chunk.wav", sr=None)
-            transcription = stt_pipeline(chunk_audio) # Transcribe using Wav2Vec2
-
-            # Match transcription segment with diarization result
+            transcription = stt_pipeline(chunk_audio)
+
+            # Match transcription segment with diarization
            speaker_label = "Unknown"
            for speaker in diarization:
                spk_start, spk_end, label = speaker['start'], speaker['end'], speaker['label']
-                # Adjust timestamp matching
-                if spk_start <= (chunk.start_time / 1000) <= spk_end: # Convert ms to seconds
+                if spk_start <= chunk_start_time <= spk_end: # Use manually tracked start time
                    speaker_label = label
                    break
-
+
            transcriptions.append(f"Speaker {speaker_label}: {transcription['text']}")
-
-            # Clean up temporary files
-            os.remove("chunk.wav")
+
+            os.remove("chunk.wav") # Clean up temporary file
 
        return "\n".join(transcriptions)
 
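The main fix in the hunk above: a pydub AudioSegment slice does not carry its position in the original recording, so the removed chunk.start_time lookup relied on an attribute the slices do not provide; the new code instead keeps the offset from the (start, end) millisecond pairs returned by detect_nonsilent. A small sketch of that pattern in isolation; the file name is illustrative:

from pydub import AudioSegment
from pydub.silence import detect_nonsilent

sound = AudioSegment.from_wav("converted_audio.wav")  # any WAV file
for start_ms, end_ms in detect_nonsilent(sound, min_silence_len=1000, silence_thresh=sound.dBFS - 14):
    chunk = sound[start_ms:end_ms]        # plain AudioSegment; it does not remember its offset
    chunk_start_time = start_ms / 1000.0  # so the offset is taken from detect_nonsilent itself
    print(f"chunk at {chunk_start_time:.2f}s, {len(chunk) / 1000.0:.2f}s long")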
 
 
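How process_audio is exposed to users is outside this diff; for a Space whose entry point is app.py the usual pattern is a small Gradio interface, so the following wiring is only an assumed example and not part of the commit:

import gradio as gr

demo = gr.Interface(
    fn=process_audio,                  # defined in app.py above
    inputs=gr.Audio(type="filepath"),  # process_audio expects a file path
    outputs="text",
)
demo.launch()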