EladSpamson committed on
Commit
fdb39f4
·
verified ·
1 Parent(s): 7f353a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -20
app.py CHANGED
@@ -1,16 +1,18 @@
 
 
 
 
 
 
1
  from flask import Flask, request, jsonify
2
  import requests
3
  import torch
4
  import librosa
5
- import os
6
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
7
 
8
- # Explicitly set writable cache directory
9
- os.environ['HF_HOME'] = '/tmp/hf_cache'
10
-
11
  app = Flask(__name__)
12
 
13
- # Temporarily using smaller model for faster testing
14
  model_id = "openai/whisper-base"
15
  processor = WhisperProcessor.from_pretrained(model_id)
16
  model = WhisperForConditionalGeneration.from_pretrained(model_id)
@@ -21,45 +23,39 @@ model.to(device)
21
  forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
22
 
23
  def transcribe_audio(audio_url):
 
24
  response = requests.get(audio_url)
25
  audio_path = "/tmp/temp_audio.wav"
26
  with open(audio_path, "wb") as f:
27
  f.write(response.content)
28
 
29
- # Load audio
30
  waveform, sr = librosa.load(audio_path, sr=16000)
31
 
32
- # Safety limit (1 hour)
33
  max_duration_sec = 3600
34
  waveform = waveform[:sr * max_duration_sec]
35
 
36
- # Split into smaller chunks
37
  chunk_duration_sec = 25
38
  chunk_size = sr * chunk_duration_sec
39
  chunks = [waveform[i : i + chunk_size] for i in range(0, len(waveform), chunk_size)]
40
 
41
  partial_text = ""
42
  for chunk in chunks:
43
- inputs = processor(
44
- chunk,
45
- sampling_rate=16000,
46
- return_tensors="pt",
47
- padding=True
48
- )
49
  input_features = inputs.input_features.to(device)
50
 
51
- # Generate text
52
  with torch.no_grad():
53
  predicted_ids = model.generate(
54
  input_features,
55
  forced_decoder_ids=forced_decoder_ids
56
  )
57
 
58
- transcription = processor.batch_decode(
59
- predicted_ids,
60
- skip_special_tokens=True
61
- )[0]
62
-
63
  partial_text += transcription + "\n"
64
 
65
  return partial_text.strip()
@@ -71,8 +67,10 @@ def transcribe_endpoint():
71
  if not audio_url:
72
  return jsonify({"error": "Missing 'audio_url' in request"}), 400
73
 
 
74
  transcription = transcribe_audio(audio_url)
75
  return jsonify({"transcription": transcription})
76
 
77
  if __name__ == '__main__':
 
78
  app.run(host="0.0.0.0", port=7860)
 
1
import os

# Point every Hugging Face cache at a writable location. This must run
# before `transformers` is imported anywhere in the process, because the
# library reads these variables at import time.
_HF_CACHE_DIR = "/tmp/hf_cache"
os.environ["HF_HOME"] = _HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = _HF_CACHE_DIR
7
  from flask import Flask, request, jsonify
8
  import requests
9
  import torch
10
  import librosa
 
11
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
12
 
 
 
 
13
app = Flask(__name__)

# Whisper "base" checkpoint: deliberately small so the model loads
# quickly on CPU-only hosts.
model_id = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)
 
23
# Decoder prompt forcing Hebrew transcription on every generate() call.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
24
 
25
def transcribe_audio(audio_url):
    """Download an audio file and return its Hebrew transcription.

    The file is fetched to /tmp, loaded at 16 kHz, truncated to at most
    one hour, and transcribed chunk by chunk with the module-level
    Whisper model.

    Args:
        audio_url: HTTP(S) URL of the audio file to transcribe.

    Returns:
        str: the concatenated chunk transcriptions, one per line,
        stripped of leading/trailing whitespace.

    Raises:
        requests.HTTPError: if the download returns an error status.
        requests.RequestException: on network failure or timeout.
    """
    # 1) Download the file to /tmp. The timeout keeps a stalled server
    # from hanging this worker forever, and raise_for_status() stops an
    # HTML error page from being saved as if it were audio.
    response = requests.get(audio_url, timeout=60)
    response.raise_for_status()
    audio_path = "/tmp/temp_audio.wav"
    with open(audio_path, "wb") as f:
        f.write(response.content)

    # 2) Load audio with librosa, resampled to the 16 kHz rate the
    # processor expects.
    waveform, sr = librosa.load(audio_path, sr=16000)

    # 3) Truncate to 1 hour max as a safety limit on request size.
    max_duration_sec = 3600
    waveform = waveform[:sr * max_duration_sec]

    # 4) Split into 25-second chunks so each generate() call stays small.
    chunk_duration_sec = 25
    chunk_size = sr * chunk_duration_sec
    chunks = [waveform[i : i + chunk_size] for i in range(0, len(waveform), chunk_size)]

    partial_text = ""
    for chunk in chunks:
        # Preprocess the raw samples into model input features.
        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
        input_features = inputs.input_features.to(device)

        # Transcribe the chunk; inference only, so no gradients needed.
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids
            )

        # Convert token IDs back to text.
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        partial_text += transcription + "\n"

    return partial_text.strip()
 
67
  if not audio_url:
68
  return jsonify({"error": "Missing 'audio_url' in request"}), 400
69
 
70
+ # Perform transcription
71
  transcription = transcribe_audio(audio_url)
72
  return jsonify({"transcription": transcription})
73
 
74
if __name__ == '__main__':
    # Run the Flask app on all interfaces, port 7860.
    # NOTE(review): 7860 looks like a hosting-platform convention — confirm
    # against the deployment target before changing it.
    app.run(host="0.0.0.0", port=7860)