EladSpamson committed
Commit 66d0ca2 · verified · 1 Parent(s): 1c7c059

Update app.py

Files changed (1)
  1. app.py +24 -17
app.py CHANGED
@@ -1,6 +1,6 @@
 import os
 
-# Ensure environment variables are set before Transformers are imported
+# Set environment variables so HF uses /tmp for caching
 os.environ["HF_HOME"] = "/tmp/hf_cache"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
 os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
@@ -14,7 +14,7 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
 app = Flask(__name__)
 
-# Using a smaller model for faster CPU loading
+# Use a multilingual model capable of Hebrew (e.g. whisper-base)
 model_id = "openai/whisper-base"
 processor = WhisperProcessor.from_pretrained(model_id)
 model = WhisperForConditionalGeneration.from_pretrained(model_id)
@@ -22,6 +22,9 @@ model = WhisperForConditionalGeneration.from_pretrained(model_id)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
+# Force Hebrew transcription tokens so no auto-detect occurs
+forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
+
 def transcribe_audio(audio_url):
     # 1) Download audio file to /tmp
     response = requests.get(audio_url)
@@ -29,41 +32,45 @@ def transcribe_audio(audio_url):
     with open(audio_path, "wb") as f:
         f.write(response.content)
 
-    # 2) Load audio with librosa
+    # 2) Load with librosa
     waveform, sr = librosa.load(audio_path, sr=16000)
 
-    # 3) Optional safety limit (1 hour)
-    max_duration_sec = 3600
-    waveform = waveform[:sr * max_duration_sec]
+    # 3) Optional: limit to 1 hour
+    max_sec = 3600
+    waveform = waveform[: sr * max_sec]
 
-    # 4) Split into smaller chunks (25s)
-    chunk_duration_sec = 25
-    chunk_size = sr * chunk_duration_sec
+    # 4) Split into 25-second chunks (or pick any chunk size)
+    chunk_sec = 25
+    chunk_size = sr * chunk_sec
     chunks = [waveform[i : i + chunk_size] for i in range(0, len(waveform), chunk_size)]
 
     partial_text = ""
     for chunk in chunks:
-        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
+        # Preprocess chunk to mel
+        inputs = processor(chunk, sampling_rate=sr, return_tensors="pt", padding=True)
         input_features = inputs.input_features.to(device)
 
-        # **No** forced_decoder_ids => Whisper auto-detects language
         with torch.no_grad():
-            predicted_ids = model.generate(input_features)
+            # Force Hebrew so no meltdown on short audio
+            predicted_ids = model.generate(
+                input_features,
+                forced_decoder_ids=forced_decoder_ids
+            )
 
         transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
         partial_text += transcription + "\n"
 
     return partial_text.strip()
 
-@app.route('/transcribe', methods=['POST'])
+@app.route("/transcribe", methods=["POST"])
 def transcribe_endpoint():
     data = request.get_json()
-    audio_url = data.get('audio_url')
+    audio_url = data.get("audio_url")
     if not audio_url:
         return jsonify({"error": "Missing 'audio_url' in request"}), 400
 
-    transcription = transcribe_audio(audio_url)
-    return jsonify({"transcription": transcription})
+    text = transcribe_audio(audio_url)
+    return jsonify({"transcription": text})
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     app.run(host="0.0.0.0", port=7860)
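
For reference, a minimal sketch of how the updated /transcribe endpoint could be exercised once the app is running. The base URL assumes a local run on port 7860 (matching app.run above); the audio URL is a hypothetical placeholder, not a file from this repo:

import requests

# Hypothetical client for the /transcribe endpoint in app.py.
# Assumes the Flask app is reachable on localhost:7860; the audio URL
# below is a placeholder for any downloadable audio file.
resp = requests.post(
    "http://localhost:7860/transcribe",
    json={"audio_url": "https://example.com/sample.mp3"},
)
print(resp.status_code)  # 400 if 'audio_url' is missing from the JSON body
print(resp.json())       # {"transcription": "..."} on success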