EladSpamson committed on
Commit
aefce6b
·
verified ·
1 Parent(s): 5b89128

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -12
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
 
3
- # Must set environment variables before importing Transformers
4
  os.environ["HF_HOME"] = "/tmp/hf_cache"
5
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
6
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
@@ -15,7 +15,7 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration
15
 
16
  app = Flask(__name__)
17
 
18
- # Use your custom Hebrew Whisper model
19
  model_id = "ivrit-ai/whisper-large-v3-turbo"
20
  processor = WhisperProcessor.from_pretrained(model_id)
21
  model = WhisperForConditionalGeneration.from_pretrained(model_id)
@@ -23,22 +23,21 @@ model = WhisperForConditionalGeneration.from_pretrained(model_id)
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
  model.to(device)
25
 
26
- # Force Hebrew so it won't require 30s audio for language detection
27
  forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
28
 
29
  def transcribe_audio(audio_url):
30
- # 1) Download audio file to /tmp
31
  response = requests.get(audio_url)
32
  audio_path = "/tmp/temp_audio.wav"
33
  with open(audio_path, "wb") as f:
34
  f.write(response.content)
35
 
36
- # 2) Load with librosa
37
  waveform, sr = librosa.load(audio_path, sr=16000)
38
 
39
- # 3) (Optional) limit up to 1 hour
40
- max_sec = 3600
41
- waveform = waveform[: sr * max_sec]
42
 
43
  # 4) Split into 25-second chunks
44
  chunk_sec = 25
@@ -47,11 +46,10 @@ def transcribe_audio(audio_url):
47
 
48
  partial_text = ""
49
  for chunk in chunks:
50
- # Preprocess chunk → mel spectrogram
51
  inputs = processor(chunk, sampling_rate=sr, return_tensors="pt", padding=True)
52
  input_features = inputs.input_features.to(device)
53
 
54
- # Force Hebrew, skipping auto-detect
55
  with torch.no_grad():
56
  predicted_ids = model.generate(
57
  input_features,
@@ -70,10 +68,9 @@ def transcribe_endpoint():
70
  if not audio_url:
71
  return jsonify({"error": "Missing 'audio_url' in request"}), 400
72
 
73
- # Perform forced Hebrew transcription
74
  text = transcribe_audio(audio_url)
75
 
76
- # Return raw Hebrew in JSON
77
  payload = {"Transcription": text}
78
  return Response(
79
  json.dumps(payload, ensure_ascii=False),
 
1
  import os
2
 
3
+ # Environment variables to avoid permission issues
4
  os.environ["HF_HOME"] = "/tmp/hf_cache"
5
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
6
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
 
15
 
16
  app = Flask(__name__)
17
 
18
+ # Use your custom Hebrew Whisper model (example: ivrit-ai/whisper-large-v3-turbo)
19
  model_id = "ivrit-ai/whisper-large-v3-turbo"
20
  processor = WhisperProcessor.from_pretrained(model_id)
21
  model = WhisperForConditionalGeneration.from_pretrained(model_id)
 
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
  model.to(device)
25
 
26
+ # Force Hebrew to skip auto-detect
27
  forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
28
 
29
  def transcribe_audio(audio_url):
30
+ # 1) Download audio file
31
  response = requests.get(audio_url)
32
  audio_path = "/tmp/temp_audio.wav"
33
  with open(audio_path, "wb") as f:
34
  f.write(response.content)
35
 
36
+ # 2) Load audio with librosa
37
  waveform, sr = librosa.load(audio_path, sr=16000)
38
 
39
+ # 3) Limit to 1 hour
40
+ waveform = waveform[: sr * 3600]
 
41
 
42
  # 4) Split into 25-second chunks
43
  chunk_sec = 25
 
46
 
47
  partial_text = ""
48
  for chunk in chunks:
 
49
  inputs = processor(chunk, sampling_rate=sr, return_tensors="pt", padding=True)
50
  input_features = inputs.input_features.to(device)
51
 
52
+ # Generate forced-Hebrew transcription
53
  with torch.no_grad():
54
  predicted_ids = model.generate(
55
  input_features,
 
68
  if not audio_url:
69
  return jsonify({"error": "Missing 'audio_url' in request"}), 400
70
 
 
71
  text = transcribe_audio(audio_url)
72
 
73
+ # Return Hebrew characters directly
74
  payload = {"Transcription": text}
75
  return Response(
76
  json.dumps(payload, ensure_ascii=False),