Kr08 committed (verified)
Commit fd4f883 · 1 Parent(s): 6db9237

Update audio_processing.py

Files changed (1)
  1. audio_processing.py +82 -68
audio_processing.py CHANGED
@@ -36,6 +36,9 @@ def load_models(model_size="small"):
     device = "cpu"
     compute_type = "int8"
     whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
+
+def load_diarization_pipeline():
+    global diarization_pipeline, device
 
     # Try to initialize diarization pipeline
     try:
@@ -55,8 +58,55 @@ def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000
         chunks.append(chunk)
     return chunks
 
+def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
+    merged = []
+    for segment in segments:
+        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
+            merged.append(segment)
+        else:
+            # Find the overlap
+            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
+            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
+
+            if match.size / len(segment['text']) > similarity_threshold:
+                # Merge the segments
+                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
+                merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
+
+                merged[-1]['end'] = segment['end']
+                merged[-1]['text'] = merged_text
+                if 'translated' in segment:
+                    merged[-1]['translated'] = merged_translated
+            else:
+                # If no significant overlap, append as a new segment
+                merged.append(segment)
+    return merged
+
+# Helper function to get the most common speaker in a time range
+def get_most_common_speaker(diarization_result, start_time, end_time):
+    speakers = []
+    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
+        if turn.start <= end_time and turn.end >= start_time:
+            speakers.append(speaker)
+    return max(set(speakers), key=speakers.count) if speakers else "Unknown"
+
+# Helper function to split long audio files
+def split_audio(audio, max_duration=30):
+    sample_rate = 16000
+    max_samples = max_duration * sample_rate
+
+    if len(audio) <= max_samples:
+        return [audio]
+
+    splits = []
+    for i in range(0, len(audio), max_samples):
+        splits.append(audio[i:i+max_samples])
+
+    return splits
+
+# Main processing function with optimizations
 @spaces.GPU
-def process_audio(audio_file, translate=False, model_size="small"):
+def process_audio_optimized(audio_file, translate=False, model_size="small", use_diarization=True):
     global whisper_model, diarization_pipeline
 
     if whisper_model is None:
@@ -66,71 +116,55 @@ def process_audio(audio_file, translate=False, model_size="small"):
 
     try:
         audio = whisperx.load_audio(audio_file)
+        audio_splits = split_audio(audio)
 
-        # Perform diarization if pipeline is available
+        # Perform diarization if requested and pipeline is available
         diarization_result = None
-        if diarization_pipeline is not None:
-            try:
-                diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
-            except Exception as e:
-                logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")
-
-        chunks = preprocess_audio(audio)
+        if use_diarization:
+            if diarization_pipeline is None:
+                load_diarization_pipeline()
+            if diarization_pipeline is not None:
+                try:
+                    diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
+                except Exception as e:
+                    logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")
 
         language_segments = []
         final_segments = []
 
-        overlap_duration = 2 # 2 seconds overlap
-        for i, chunk in enumerate(chunks):
-            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
-            chunk_end_time = chunk_start_time + CHUNK_LENGTH
-            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
-            lang = whisper_model.detect_language(chunk)
-            result_transcribe = whisper_model.transcribe(chunk, language=lang)
-            if translate:
-                result_translate = whisper_model.transcribe(chunk, task="translate")
-            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
-            for j, t_seg in enumerate(result_transcribe["segments"]):
-                segment_start = chunk_start_time + t_seg["start"]
-                segment_end = chunk_start_time + t_seg["end"]
-                # Skip segments in the overlapping region of the previous chunk
-                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
-                    print(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
-                    continue
+        for i, audio_split in enumerate(audio_splits):
+            logger.info(f"Processing split {i+1}/{len(audio_splits)}")
 
-                # Skip segments in the overlapping region of the next chunk
-                if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
-                    print(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
-                    continue
+            result = whisper_model.transcribe(audio_split)
+            lang = result["language"]
 
+            for segment in result["segments"]:
+                segment_start = segment["start"] + (i * 30) # Adjust start time based on split
+                segment_end = segment["end"] + (i * 30) # Adjust end time based on split
+
                 speaker = "Unknown"
                 if diarization_result is not None:
-                    speakers = []
-                    for turn, track, spk in diarization_result.itertracks(yield_label=True):
-                        if turn.start <= segment_end and turn.end >= segment_start:
-                            speakers.append(spk)
-                    speaker = max(set(speakers), key=speakers.count) if speakers else "Unknown"
-
-                segment = {
+                    speaker = get_most_common_speaker(diarization_result, segment_start, segment_end)
+
+                final_segment = {
                     "start": segment_start,
                     "end": segment_end,
                     "language": lang,
                     "speaker": speaker,
-                    "text": t_seg["text"],
+                    "text": segment["text"],
                 }
 
                 if translate:
-                    segment["translated"] = result_translate["segments"][j]["text"]
+                    translation = whisper_model.transcribe(audio_split[int(segment["start"]*16000):int(segment["end"]*16000)], task="translate")
+                    final_segment["translated"] = translation["text"]
+
+                final_segments.append(final_segment)
 
-                final_segments.append(segment)
-
             language_segments.append({
                 "language": lang,
-                "start": chunk_start_time,
-                "end": chunk_start_time + CHUNK_LENGTH
+                "start": i * 30,
+                "end": min((i + 1) * 30, len(audio) / 16000)
            })
-            chunk_end_time = time.time()
-            logger.info(f"Chunk {i+1} processed in {chunk_end_time - chunk_start_time:.2f} seconds")
 
         final_segments.sort(key=lambda x: x["start"])
         merged_segments = merge_nearby_segments(final_segments)
@@ -143,26 +177,6 @@ def process_audio(audio_file, translate=False, model_size="small"):
         logger.error(f"An error occurred during audio processing: {str(e)}")
         raise
 
-def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
-    merged = []
-    for segment in segments:
-        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
-            merged.append(segment)
-        else:
-            # Find the overlap
-            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
-            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
-
-            if match.size / len(segment['text']) > similarity_threshold:
-                # Merge the segments
-                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
-                merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
-
-                merged[-1]['end'] = segment['end']
-                merged[-1]['text'] = merged_text
-                if 'translated' in segment:
-                    merged[-1]['translated'] = merged_translated
-            else:
-                # If no significant overlap, append as a new segment
-                merged.append(segment)
-    return merged
+# You can keep the original process_audio function for backwards compatibility
+# or replace it with the optimized version
+process_audio = process_audio_optimized
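Note on the relocated merge_nearby_segments helper: it uses difflib.SequenceMatcher to find the longest run of text shared by two adjacent segments and splices them only when that overlap covers more than similarity_threshold of the second segment's text. A minimal, self-contained sketch of that splice step (toy strings and made-up numbers, not output from this pipeline):

# Overlap-splicing idea behind merge_nearby_segments, in isolation.
from difflib import SequenceMatcher

prev_text = "and then we went to the market to buy"
next_text = "to the market to buy fruit"

matcher = SequenceMatcher(None, prev_text, next_text)
match = matcher.find_longest_match(0, len(prev_text), 0, len(next_text))

# Merge only if the shared run covers most of the second text (threshold 0.7)
if match.size / len(next_text) > 0.7:
    merged = prev_text + next_text[match.b + match.size:]
else:
    merged = None  # kept as a separate segment instead

print(merged)  # -> and then we went to the market to buy fruit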
 
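The timestamp bookkeeping in process_audio_optimized assumes fixed 30-second splits: split_audio cuts the waveform every max_duration * 16000 samples, and each segment's split-local times are shifted by i * 30 to get global times. A quick synthetic illustration of that arithmetic (numpy only; the durations are made up):

# How split_audio() chunks a 16 kHz waveform and how local times map back.
import numpy as np

sample_rate = 16000
audio = np.zeros(75 * sample_rate, dtype=np.float32)  # 75 s of silence

max_samples = 30 * sample_rate
splits = [audio[i:i + max_samples] for i in range(0, len(audio), max_samples)]
print([len(s) / sample_rate for s in splits])  # -> [30.0, 30.0, 15.0]

# A segment reported at 4.2-6.8 s inside split i=2 maps to global time:
i, local_start, local_end = 2, 4.2, 6.8
print(round(i * 30 + local_start, 2), round(i * 30 + local_end, 2))  # -> 64.2 66.8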
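Finally, a small sanity check of the new get_most_common_speaker helper. The sketch below copies the committed function and feeds it a hand-built pyannote.core Annotation (made-up turns; the real diarization_result comes from the diarization pipeline), since Annotation exposes the same itertracks(yield_label=True) interface:

# Toy check of the speaker-voting logic (assumes pyannote.core is installed).
from pyannote.core import Annotation, Segment

def get_most_common_speaker(diarization_result, start_time, end_time):
    speakers = []
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        if turn.start <= end_time and turn.end >= start_time:
            speakers.append(speaker)
    return max(set(speakers), key=speakers.count) if speakers else "Unknown"

diarization = Annotation()
diarization[Segment(0.0, 3.0)] = "SPEAKER_00"
diarization[Segment(2.5, 6.0)] = "SPEAKER_01"
diarization[Segment(6.0, 9.0)] = "SPEAKER_01"

print(get_most_common_speaker(diarization, 4.0, 8.0))    # -> SPEAKER_01
print(get_most_common_speaker(diarization, 20.0, 25.0))  # -> Unknown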