Coco-18 commited on
Commit
f03b779
·
verified ·
1 Parent(s): ddc1d69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -30
app.py CHANGED
@@ -14,6 +14,8 @@ for path in ["/tmp/hf_home", "/tmp/transformers_cache", "/tmp/huggingface_hub_ca
14
 
15
  # Now import the rest of the libraries
16
  import torch
 
 
17
  import torchaudio
18
  import soundfile as sf
19
  from flask import Flask, request, jsonify, send_file
@@ -98,49 +100,51 @@ def transcribe_audio():
98
  audio_file = request.files["audio"]
99
  language = request.form.get("language", "english").lower()
100
 
101
- # Validate language
102
  if language not in LANGUAGE_CODES:
103
  return jsonify({"error": f"Unsupported language: {language}"}), 400
104
 
105
- # Get the language code for the ASR model
106
  lang_code = LANGUAGE_CODES[language]
107
 
108
- # Save audio file temporarily
109
- audio_path = os.path.join(OUTPUT_DIR, "input_audio")
110
- audio_file.save(audio_path)
111
-
112
- # Load and process audio
113
- try:
114
- # Load audio using torchaudio, which supports various formats
115
- waveform, sr = torchaudio.load(audio_path)
116
-
117
- # Resample if necessary
118
- if sr != SAMPLE_RATE:
119
- waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
120
-
121
- # Normalize audio (recommended for Wav2Vec2)
122
- waveform = waveform / torch.max(torch.abs(waveform))
123
-
124
- # Process audio for ASR
125
- inputs = asr_processor(
126
- waveform.squeeze().numpy(),
127
- sampling_rate=SAMPLE_RATE,
128
- return_tensors="pt",
129
- language=lang_code # Set the language code
130
- )
131
- except Exception as e:
132
- return jsonify({"error": f"Error processing audio: {str(e)}"}), 400
133
-
134
- # Transcribe
 
 
 
 
135
  with torch.no_grad():
136
  logits = asr_model(**inputs).logits
137
  ids = torch.argmax(logits, dim=-1)[0]
138
  transcription = asr_processor.decode(ids)
139
 
140
- # Log the transcription
141
  print(f"Transcription ({language}): {transcription}")
142
 
143
  return jsonify({"transcription": transcription})
 
144
  except Exception as e:
145
  print(f"ASR error: {str(e)}")
146
  return jsonify({"error": f"ASR failed: {str(e)}"}), 500
 
14
 
15
  # Now import the rest of the libraries
16
  import torch
17
+ from pydub import AudioSegment
18
+ import tempfile
19
  import torchaudio
20
  import soundfile as sf
21
  from flask import Flask, request, jsonify, send_file
 
100
  audio_file = request.files["audio"]
101
  language = request.form.get("language", "english").lower()
102
 
 
103
  if language not in LANGUAGE_CODES:
104
  return jsonify({"error": f"Unsupported language: {language}"}), 400
105
 
 
106
  lang_code = LANGUAGE_CODES[language]
107
 
108
+ # Save the uploaded file temporarily
109
+ with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(audio_file.filename)[-1]) as temp_audio:
110
+ temp_audio.write(audio_file.read())
111
+ temp_audio_path = temp_audio.name
112
+
113
+ # Convert to WAV if necessary
114
+ wav_path = temp_audio_path
115
+ if not audio_file.filename.lower().endswith(".wav"):
116
+ wav_path = os.path.join(OUTPUT_DIR, "converted_audio.wav")
117
+ audio = AudioSegment.from_file(temp_audio_path)
118
+ audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
119
+ audio.export(wav_path, format="wav")
120
+
121
+ # Load and process the WAV file
122
+ waveform, sr = torchaudio.load(wav_path)
123
+
124
+ # Resample if needed
125
+ if sr != SAMPLE_RATE:
126
+ waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
127
+
128
+ waveform = waveform / torch.max(torch.abs(waveform))
129
+
130
+ # Process audio for ASR
131
+ inputs = asr_processor(
132
+ waveform.squeeze().numpy(),
133
+ sampling_rate=SAMPLE_RATE,
134
+ return_tensors="pt",
135
+ language=lang_code
136
+ )
137
+
138
+ # Perform ASR
139
  with torch.no_grad():
140
  logits = asr_model(**inputs).logits
141
  ids = torch.argmax(logits, dim=-1)[0]
142
  transcription = asr_processor.decode(ids)
143
 
 
144
  print(f"Transcription ({language}): {transcription}")
145
 
146
  return jsonify({"transcription": transcription})
147
+
148
  except Exception as e:
149
  print(f"ASR error: {str(e)}")
150
  return jsonify({"error": f"ASR failed: {str(e)}"}), 500