EladSpamson committed
Commit aa43ea6 · verified · Parent: 84b9b91

Update app.py

Files changed (1): app.py (+43 -3)
app.py CHANGED
@@ -5,11 +5,12 @@ import librosa
 import os
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
-# Set Hugging Face cache directory explicitly
-os.environ["HF_HOME"] = "/tmp/hf_cache"
+# Explicitly set writable cache directory
+os.environ['HF_HOME'] = '/tmp/hf_cache'
 
 app = Flask(__name__)
 
+# Temporarily using smaller model for faster testing
 model_id = "openai/whisper-base"
 processor = WhisperProcessor.from_pretrained(model_id)
 model = WhisperForConditionalGeneration.from_pretrained(model_id)
@@ -19,4 +20,43 @@ model.to(device)
 
 forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
 
-# rest of the code remains unchanged...
+def transcribe_audio(audio_url):
+    response = requests.get(audio_url)
+    audio_path = "/tmp/temp_audio.wav"
+    with open(audio_path, "wb") as f:
+        f.write(response.content)
+
+    waveform, sr = librosa.load(audio_path, sr=16000)
+    max_duration_sec = 3600
+    waveform = waveform[:sr * max_duration_sec]
+
+    chunk_duration_sec = 25
+    chunk_size = sr * chunk_duration_sec
+    chunks = [waveform[i:i + chunk_size] for i in range(0, len(waveform), chunk_size)]
+
+    partial_text = ""
+    for chunk in chunks:
+        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
+        input_features = inputs.input_features.to(device)
+
+        with torch.no_grad():
+            predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
+
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+        partial_text += transcription + "\n"
+
+    return partial_text.strip()
+
+@app.route('/transcribe', methods=['POST'])
+def transcribe_endpoint():
+    data = request.get_json()
+    audio_url = data.get('audio_url')
+    if not audio_url:
+        return jsonify({"error": "Missing 'audio_url' in request"}), 400
+
+    transcription = transcribe_audio(audio_url)
+
+    return jsonify({"transcription": transcription})
+
+if __name__ == '__main__':
+    app.run(host="0.0.0.0", port=7860)
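
The new code replaces the old "# rest of the code remains unchanged..." placeholder with a full pipeline: the audio is fetched over HTTP, resampled to 16 kHz (the rate Whisper expects), capped at one hour, and split into 25-second chunks so each segment fits inside Whisper's 30-second input window before being decoded with the Hebrew forced-decoder prompt. Once the Space is running, the POST /transcribe endpoint can be exercised with a short client script. A minimal sketch, assuming the app is reachable at http://localhost:7860 and using a placeholder audio URL (both are stand-ins, not values from the commit):

import requests

BASE_URL = "http://localhost:7860"  # placeholder; substitute the deployed Space URL

payload = {"audio_url": "https://example.com/sample.wav"}  # placeholder audio URL

# transcribe_audio runs synchronously inside the request handler,
# so allow a timeout proportional to the audio length
resp = requests.post(f"{BASE_URL}/transcribe", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["transcription"])

Because the endpoint transcribes synchronously, response time grows with the length of the audio; the one-hour cap in transcribe_audio bounds the worst case.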