EladSpamson committed on
Commit
a6aa54b
·
verified ·
1 Parent(s): 43cc63d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import os
2
 
3
- # Ensure environment variables are set before loading transformers
4
  os.environ["HF_HOME"] = "/tmp/hf_cache"
5
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
 
 
6
 
7
  from flask import Flask, request, jsonify
8
  import requests
@@ -12,7 +14,7 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration
12
 
13
  app = Flask(__name__)
14
 
15
- # Use a smaller model for faster CPU loading
16
  model_id = "openai/whisper-base"
17
  processor = WhisperProcessor.from_pretrained(model_id)
18
  model = WhisperForConditionalGeneration.from_pretrained(model_id)
@@ -23,16 +25,16 @@ model.to(device)
23
  forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
24
 
25
  def transcribe_audio(audio_url):
26
- # 1) Download the file to /tmp
27
  response = requests.get(audio_url)
28
  audio_path = "/tmp/temp_audio.wav"
29
  with open(audio_path, "wb") as f:
30
  f.write(response.content)
31
 
32
- # 2) Load audio with librosa
33
  waveform, sr = librosa.load(audio_path, sr=16000)
34
 
35
- # 3) Truncate to 1 hour max
36
  max_duration_sec = 3600
37
  waveform = waveform[:sr * max_duration_sec]
38
 
@@ -43,18 +45,15 @@ def transcribe_audio(audio_url):
43
 
44
  partial_text = ""
45
  for chunk in chunks:
46
- # Preprocess chunk
47
  inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
48
  input_features = inputs.input_features.to(device)
49
 
50
- # Transcribe chunk
51
  with torch.no_grad():
52
  predicted_ids = model.generate(
53
- input_features,
54
  forced_decoder_ids=forced_decoder_ids
55
  )
56
 
57
- # Convert IDs back to text
58
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
59
  partial_text += transcription + "\n"
60
 
@@ -67,10 +66,8 @@ def transcribe_endpoint():
67
  if not audio_url:
68
  return jsonify({"error": "Missing 'audio_url' in request"}), 400
69
 
70
- # Perform transcription
71
  transcription = transcribe_audio(audio_url)
72
  return jsonify({"transcription": transcription})
73
 
74
  if __name__ == '__main__':
75
- # Run the Flask app on port 7860
76
  app.run(host="0.0.0.0", port=7860)
 
1
  import os
2
 
3
+ # Set environment variables VERY early, before HF or Transformers are imported:
4
  os.environ["HF_HOME"] = "/tmp/hf_cache"
5
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
6
+ os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
7
+ os.environ["XDG_CACHE_HOME"] = "/tmp"
8
 
9
  from flask import Flask, request, jsonify
10
  import requests
 
14
 
15
  app = Flask(__name__)
16
 
17
+ # Use a smaller model for CPU
18
  model_id = "openai/whisper-base"
19
  processor = WhisperProcessor.from_pretrained(model_id)
20
  model = WhisperForConditionalGeneration.from_pretrained(model_id)
 
25
  forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
26
 
27
  def transcribe_audio(audio_url):
28
+ # 1) Download audio file to /tmp
29
  response = requests.get(audio_url)
30
  audio_path = "/tmp/temp_audio.wav"
31
  with open(audio_path, "wb") as f:
32
  f.write(response.content)
33
 
34
+ # 2) Load with librosa
35
  waveform, sr = librosa.load(audio_path, sr=16000)
36
 
37
+ # 3) Truncate to 1 hour
38
  max_duration_sec = 3600
39
  waveform = waveform[:sr * max_duration_sec]
40
 
 
45
 
46
  partial_text = ""
47
  for chunk in chunks:
 
48
  inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
49
  input_features = inputs.input_features.to(device)
50
 
 
51
  with torch.no_grad():
52
  predicted_ids = model.generate(
53
+ input_features,
54
  forced_decoder_ids=forced_decoder_ids
55
  )
56
 
 
57
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
58
  partial_text += transcription + "\n"
59
 
 
66
  if not audio_url:
67
  return jsonify({"error": "Missing 'audio_url' in request"}), 400
68
 
 
69
  transcription = transcribe_audio(audio_url)
70
  return jsonify({"transcription": transcription})
71
 
72
  if __name__ == '__main__':
 
73
  app.run(host="0.0.0.0", port=7860)