Commit 6c36e37 (verified) · Parent(s): c569b48
Kr08 committed

Update audio_processing.py

Files changed (1): audio_processing.py (+10, -75)

audio_processing.py CHANGED
@@ -3,48 +3,42 @@ import whisper
 import numpy as np
 import torchaudio as ta
 import gradio as gr
-import spaces
 from model_utils import get_processor, get_model, get_whisper_model_small, get_device
 from config import SAMPLING_RATE, CHUNK_LENGTH_S
 import subprocess
 
-import subprocess
-import torchaudio as ta
-
-
 def resample_with_ffmpeg(input_file, output_file, target_sr=16000):
     command = [
         'ffmpeg', '-i', input_file, '-ar', str(target_sr), output_file
     ]
     subprocess.run(command, check=True)
 
-@spaces.GPU
 def detect_language(audio):
     whisper_model = get_whisper_model_small()
-
+
     # Save the input audio to a temporary file
     ta.save("input_audio.wav", torch.tensor(audio[1]).unsqueeze(0), audio[0])
-
+
     # Resample if necessary using ffmpeg
     if audio[0] != SAMPLING_RATE:
         resample_with_ffmpeg("input_audio.wav", "resampled_audio.wav", target_sr=SAMPLING_RATE)
         audio_tensor, _ = ta.load("resampled_audio.wav")
     else:
         audio_tensor = torch.tensor(audio[1]).float()
-
+
     # Ensure the audio is in the correct shape (mono)
     if audio_tensor.dim() == 2:
         audio_tensor = audio_tensor.mean(dim=0)
-
+
     # Use Whisper's preprocessing
     audio_tensor = whisper.pad_or_trim(audio_tensor)
     print(f"Audio length after pad/trim: {audio_tensor.shape[-1] / SAMPLING_RATE} seconds")
     mel = whisper.log_mel_spectrogram(audio_tensor).to(whisper_model.device)
-
+
     # Detect language
     _, probs = whisper_model.detect_language(mel)
     detected_lang = max(probs, key=probs.get)
-
+
     print(f"Audio shape: {audio_tensor.shape}")
     print(f"Mel spectrogram shape: {mel.shape}")
     print(f"Detected language: {detected_lang}")
@@ -52,73 +46,17 @@ def detect_language(audio):
 
     return detected_lang
 
-
-@spaces.GPU
 def process_long_audio(audio, task="transcribe", language=None):
-    if audio[0] != SAMPLING_RATE:
-        # Save the input audio to a file for ffmpeg processing
-        ta.save("input_audio_1.wav", torch.tensor(audio[1]).unsqueeze(0), audio[0])
+    # ... (rest of the function remains the same)
 
-        # Resample using ffmpeg
-        try:
-            resample_with_ffmpeg("input_audio_1.wav", "resampled_audio_2.wav", target_sr=SAMPLING_RATE)
-        except subprocess.CalledProcessError as e:
-            print(f"ffmpeg failed: {e.stderr}")
-            raise e
-
-        waveform, _ = ta.load("resampled_audio_2.wav")
-    else:
-        waveform = torch.tensor(audio[1]).float()
-
-    # Ensure the audio is in the correct shape (mono)
-    if waveform.dim() == 2:
-        waveform = waveform.mean(dim=0)
-
-    print(f"Waveform shape after processing: {waveform.shape}")
-
-    if waveform.numel() == 0:
-        raise ValueError("Waveform is empty. Please check the input audio file.")
-
-    input_length = waveform.shape[0]  # Since waveform is 1D, access the length with shape[0]
-    chunk_length = int(CHUNK_LENGTH_S * SAMPLING_RATE)
-
-    # Corrected slicing for 1D tensor
-    chunks = [waveform[i:i + chunk_length] for i in range(0, input_length, chunk_length)]
-
-    # Initialize the processor
-    processor = get_processor()
-    model = get_model()
-    device = get_device()
-
-    results = []
-    for chunk in chunks:
-        input_features = processor(chunk, sampling_rate=SAMPLING_RATE, return_tensors="pt").input_features.to(device)
-
-        with torch.no_grad():
-            if task == "translate":
-                forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")
-                generated_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
-            else:
-                generated_ids = model.generate(input_features)
-
-        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
-        results.extend(transcription)
-
-        # Clear GPU cache
-        torch.cuda.empty_cache()
-
-    return " ".join(results)
-
-
-@spaces.GPU
 def process_audio(audio):
     if audio is None:
         return "No file uploaded", "", ""
-
+
     detected_lang = detect_language(audio)
     transcription = process_long_audio(audio, task="transcribe")
     translation = process_long_audio(audio, task="translate", language=detected_lang)
-
+
     return detected_lang, transcription, translation
 
 # Gradio interface
@@ -134,7 +72,4 @@ iface = gr.Interface(
     description="Upload an audio file to detect its language, transcribe, and translate it.",
     allow_flagging="never",
     css=".output-textbox { font-family: 'Noto Sans Devanagari', sans-serif; font-size: 18px; }"
-)
-
-if __name__ == "__main__":
-    iface.launch()
+)
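For reference, a minimal usage sketch of the resample_with_ffmpeg helper this commit keeps. The file names are hypothetical placeholders; the function shells out to the ffmpeg CLI, and check=True makes it raise subprocess.CalledProcessError if ffmpeg exits nonzero.

    # Hypothetical usage; input/output paths are placeholders.
    from audio_processing import resample_with_ffmpeg

    # Runs: ffmpeg -i speech_44k.wav -ar 16000 speech_16k.wav
    resample_with_ffmpeg("speech_44k.wav", "speech_16k.wav", target_sr=16000)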
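Similarly, a minimal sketch of driving process_audio directly, assuming Gradio's numpy audio convention of a (sample_rate, samples) tuple; the one-second silent buffer below is a placeholder, not data from the repo.

    # Hypothetical driver; Gradio passes audio as (sample_rate, samples).
    import numpy as np
    from audio_processing import process_audio

    sr = 16000
    samples = np.zeros(sr, dtype=np.float32)  # placeholder: 1 s of silence
    lang, transcription, translation = process_audio((sr, samples))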