EladSpamson committed
Commit 67a7670 · verified · 1 Parent(s): 66d0ca2

Update app.py

Files changed (1): app.py +21 -11
app.py CHANGED
@@ -1,12 +1,13 @@
 import os
 
-# Set environment variables so HF uses /tmp for caching
+# Must set environment variables before importing Transformers
 os.environ["HF_HOME"] = "/tmp/hf_cache"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
 os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
 os.environ["XDG_CACHE_HOME"] = "/tmp"
 
-from flask import Flask, request, jsonify
+from flask import Flask, request, jsonify, Response
+import json
 import requests
 import torch
 import librosa
@@ -14,7 +15,8 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
 app = Flask(__name__)
 
-# Use a multilingual model capable of Hebrew (e.g. whisper-base)
+# Choose a multilingual Whisper model that includes Hebrew.
+# For CPU usage, 'openai/whisper-base' or 'openai/whisper-tiny' are typical.
 model_id = "openai/whisper-base"
 processor = WhisperProcessor.from_pretrained(model_id)
 model = WhisperForConditionalGeneration.from_pretrained(model_id)
@@ -22,36 +24,36 @@ model = WhisperForConditionalGeneration.from_pretrained(model_id)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
-# Force Hebrew transcription tokens so no auto-detect occurs
+# Force Hebrew to avoid short-audio meltdown with auto-detect
 forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
 
 def transcribe_audio(audio_url):
-    # 1) Download audio file to /tmp
+    # 1) Download audio to /tmp
     response = requests.get(audio_url)
     audio_path = "/tmp/temp_audio.wav"
     with open(audio_path, "wb") as f:
         f.write(response.content)
 
-    # 2) Load with librosa
+    # 2) Load audio with librosa
     waveform, sr = librosa.load(audio_path, sr=16000)
 
-    # 3) Optional: limit to 1 hour
+    # 3) Limit up to 1 hour for stability
     max_sec = 3600
     waveform = waveform[: sr * max_sec]
 
-    # 4) Split into 25-second chunks (or pick any chunk size)
+    # 4) Chunk the audio in 25-second intervals
     chunk_sec = 25
     chunk_size = sr * chunk_sec
     chunks = [waveform[i : i + chunk_size] for i in range(0, len(waveform), chunk_size)]
 
     partial_text = ""
     for chunk in chunks:
-        # Preprocess chunk to mel
+        # Preprocess chunk mel spectrogram
         inputs = processor(chunk, sampling_rate=sr, return_tensors="pt", padding=True)
         input_features = inputs.input_features.to(device)
 
+        # Force Hebrew to skip auto-detect logic
         with torch.no_grad():
-            # Force Hebrew so no meltdown on short audio
             predicted_ids = model.generate(
                 input_features,
                 forced_decoder_ids=forced_decoder_ids
@@ -69,8 +71,16 @@ def transcribe_endpoint():
     if not audio_url:
         return jsonify({"error": "Missing 'audio_url' in request"}), 400
 
+    # Perform forced-Hebrew transcription
     text = transcribe_audio(audio_url)
-    return jsonify({"transcription": text})
+
+    # Return JSON with no ASCII escaping (ensures real Hebrew chars)
+    payload = {"Transcription": text}
+    return Response(
+        json.dumps(payload, ensure_ascii=False),
+        status=200,
+        mimetype="application/json; charset=utf-8"
+    )
 
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=7860)
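
Note on the final hunk: Python's json.dumps escapes non-ASCII characters by default (and Flask's jsonify behaves the same way out of the box), so Hebrew text would otherwise reach the client as \uXXXX escape sequences. A minimal sketch of the difference the commit targets, using an illustrative Hebrew string:

    import json

    # Default behavior: non-ASCII characters are escaped to \uXXXX sequences.
    escaped = json.dumps({"Transcription": "שלום עולם"})

    # With ensure_ascii=False the Hebrew characters are kept verbatim.
    raw = json.dumps({"Transcription": "שלום עולם"}, ensure_ascii=False)

    print(escaped)  # {"Transcription": "\u05e9\u05dc\u05d5\u05dd \u05e2\u05d5\u05dc\u05dd"}
    print(raw)      # {"Transcription": "שלום עולם"}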
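
A hypothetical client call against the updated endpoint. The route path ("/transcribe") and the JSON request shape are assumptions, since the @app.route decorator and request parsing sit outside the hunks shown; the "Transcription" response key matches the new payload above.

    import requests

    # Assumed route path and body shape; only transcribe_endpoint's body
    # is visible in the diff above. The audio URL is a placeholder.
    resp = requests.post(
        "http://localhost:7860/transcribe",
        json={"audio_url": "https://example.com/hebrew_sample.wav"},
    )
    print(resp.json()["Transcription"])  # Hebrew text, unescaped on the wire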