Kr08 committed
Commit bbbe230 · verified · 1 Parent(s): 5449862

Update audio_processing.py

Files changed (1)
  1. audio_processing.py +89 -116
audio_processing.py CHANGED
@@ -1,4 +1,3 @@
-
 import whisperx
 import torch
 import numpy as np
@@ -10,50 +9,21 @@ load_dotenv()
 import logging
 import time
 from difflib import SequenceMatcher
-import spaces
-
 hf_token = os.getenv("HF_TOKEN")
 
-CHUNK_LENGTH = 5
-OVERLAP = 2
+CHUNK_LENGTH=5
+OVERLAP=0
+import whisperx
+import torch
+import numpy as np
+
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
+import spaces
 
 
-@spaces.GPU(duration=60)
-def load_whisper_model(model_size="small"):
-    logger.info(f"Loading Whisper model (size: {model_size})...")
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    compute_type = "float16" if device == "cuda" else "int8"
-    try:
-        model = whisperx.load_model(model_size, device, compute_type=compute_type)
-        logger.info(f"Whisper model loaded successfully on {device}")
-        return model
-    except RuntimeError as e:
-        logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
-        device = "cpu"
-        compute_type = "int8"
-        model = whisperx.load_model(model_size, device, compute_type=compute_type)
-        logger.info("Whisper model loaded successfully on CPU")
-        return model
-
-
-@spaces.GPU(duration=60)
-def load_diarization_pipeline():
-    logger.info("Loading diarization pipeline...")
-    try:
-        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
-        if torch.cuda.is_available():
-            pipeline = pipeline.to(torch.device("cuda"))
-        logger.info("Diarization pipeline loaded successfully")
-        return pipeline
-    except Exception as e:
-        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
-        return None
-
-
-def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
+def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000): # 2 seconds overlap
     chunks = []
     for i in range(0, len(audio), chunk_size - overlap):
         chunk = audio[i:i+chunk_size]
@@ -62,103 +32,75 @@ def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
         chunks.append(chunk)
     return chunks
 
-
-def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
-    merged = []
-    for segment in segments:
-        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
-            merged.append(segment)
-        else:
-            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
-            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
-
-            if match.size / len(segment['text']) > similarity_threshold:
-                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
-                merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
-
-                merged[-1]['end'] = segment['end']
-                merged[-1]['text'] = merged_text
-                if 'translated' in segment:
-                    merged[-1]['translated'] = merged_translated
-            else:
-                merged.append(segment)
-    return merged
-
-def get_most_common_speaker(diarization_result, start_time, end_time):
-    speakers = []
-    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
-        if turn.start <= end_time and turn.end >= start_time:
-            speakers.append(speaker)
-    return max(set(speakers), key=speakers.count) if speakers else "Unknown"
-
-def split_audio(audio, max_duration=30):
-    sample_rate = 16000
-    max_samples = max_duration * sample_rate
-
-    if len(audio) <= max_samples:
-        return [audio]
-
-    splits = []
-    for i in range(0, len(audio), max_samples):
-        splits.append(audio[i:i+max_samples])
-
-    return splits
-
-@spaces.GPU(duration=60)
-def process_audio(audio_file, translate=False, model_size="small", use_diarization=True):
-    logger.info(f"Starting audio processing: translate={translate}, model_size={model_size}, use_diarization={use_diarization}")
+@spaces.GPU()
+def process_audio(audio_file, translate=False, model_size="small"):
     start_time = time.time()
 
     try:
-        whisper_model = load_whisper_model(model_size)
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {device}")
+        compute_type = "int8" if torch.cuda.is_available() else "float32"
         audio = whisperx.load_audio(audio_file)
-        audio_splits = split_audio(audio)
+        model = whisperx.load_model(model_size, device, compute_type=compute_type)
+
+        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
+        diarization_pipeline = diarization_pipeline.to(torch.device(device))
+
+        diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
 
-        diarization_result = None
-        if use_diarization:
-            diarization_pipeline = load_diarization_pipeline()
-            if diarization_pipeline is not None:
-                try:
-                    diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
-                except Exception as e:
-                    logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")
+        chunks = preprocess_audio(audio)
 
         language_segments = []
        final_segments = []
 
-        for i, audio_split in enumerate(audio_splits):
-            logger.info(f"Processing split {i+1}/{len(audio_splits)}")
+        overlap_duration = OVERLAP # 2 seconds overlap
+        for i, chunk in enumerate(chunks):
+            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
+            chunk_end_time = chunk_start_time + CHUNK_LENGTH
+            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
+            lang = model.detect_language(chunk)
+            result_transcribe = model.transcribe(chunk, language=lang)
+            if translate:
+                result_translate = model.transcribe(chunk, task="translate")
+            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
+            for j, t_seg in enumerate(result_transcribe["segments"]):
+                segment_start = chunk_start_time + t_seg["start"]
+                segment_end = chunk_start_time + t_seg["end"]
+                # Skip segments in the overlapping region of the previous chunk
+                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
+                    print(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
+                    continue
 
-            result = whisper_model.transcribe(audio_split)
-            lang = result["language"]
+                # Skip segments in the overlapping region of the next chunk
+                if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
+                    print(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
+                    continue
 
-            for segment in result["segments"]:
-                segment_start = segment["start"] + (i * 30)
-                segment_end = segment["end"] + (i * 30)
-
-                speaker = "Unknown"
-                if diarization_result is not None:
-                    speaker = get_most_common_speaker(diarization_result, segment_start, segment_end)
-
-                final_segment = {
+                speakers = []
+                for turn, track, speaker in diarization_result.itertracks(yield_label=True):
+                    if turn.start <= segment_end and turn.end >= segment_start:
+                        speakers.append(speaker)
+
+                segment = {
                     "start": segment_start,
                     "end": segment_end,
                     "language": lang,
-                    "speaker": speaker,
-                    "text": segment["text"],
+                    "speaker": max(set(speakers), key=speakers.count) if speakers else "Unknown",
+                    "text": t_seg["text"],
                 }
 
                 if translate:
-                    translation = whisper_model.transcribe(audio_split[int(segment["start"]*16000):int(segment["end"]*16000)], task="translate")
-                    final_segment["translated"] = translation["text"]
-
-                final_segments.append(final_segment)
+                    segment["translated"] = result_translate["segments"][j]["text"]
 
+                final_segments.append(segment)
+
             language_segments.append({
                 "language": lang,
-                "start": i * 30,
-                "end": min((i + 1) * 30, len(audio) / 16000)
+                "start": chunk_start_time,
+                "end": chunk_start_time + CHUNK_LENGTH
             })
+            chunk_end_time = time.time()
+            logger.info(f"Chunk {i+1} processed in {chunk_end_time - chunk_start_time:.2f} seconds")
 
         final_segments.sort(key=lambda x: x["start"])
         merged_segments = merge_nearby_segments(final_segments)
@@ -166,7 +108,38 @@ def process_audio(audio_file, translate=False, model_size="small", use_diarization=True):
         end_time = time.time()
         logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
 
-        return language_segments, merged_segments
+        return language_segments, final_segments
     except Exception as e:
         logger.error(f"An error occurred during audio processing: {str(e)}")
-        raise
+        raise
+
+def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.9):
+    merged = []
+    for segment in segments:
+        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
+            merged.append(segment)
+        else:
+            # Find the overlap
+            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
+            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
+
+            if match.size / len(segment['text']) > similarity_threshold:
+                # Merge the segments
+                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
+                merged_translated = merged[-1]['translated'] + segment['translated'][match.b + match.size:]
+
+                merged[-1]['end'] = segment['end']
+                merged[-1]['text'] = merged_text
+                merged[-1]['translated'] = merged_translated
+            else:
+                # If no significant overlap, append as a new segment
+                merged.append(segment)
+    return merged
+
+def print_results(segments):
+    for segment in segments:
+        print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
+        print(f"Original: {segment['text']}")
+        if 'translated' in segment:
+            print(f"Translated: {segment['translated']}")
+        print()
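
For reference, a minimal usage sketch of the module as it stands after this commit (not part of the commit itself): it assumes the file is importable as audio_processing, that HF_TOKEN is set in the environment so the pyannote pipeline can authenticate, and that "example.wav" is a hypothetical input audio file.

# Usage sketch (assumptions: saved as audio_processing.py, HF_TOKEN set, "example.wav" is hypothetical)
from audio_processing import process_audio, print_results

# process_audio returns per-chunk language spans plus the transcribed segments
language_segments, segments = process_audio("example.wav", translate=True, model_size="small")

# Print which language was detected for each chunk window
for ls in language_segments:
    print(f"{ls['language']}: {ls['start']:.1f}s - {ls['end']:.1f}s")

# Pretty-print the segments with speaker labels and translations
print_results(segments)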