EladSpamson committed
Commit ba4bba3 · verified · 1 Parent(s): 8be8710

Update app.py

Files changed (1):
  1. app.py +65 -29
app.py CHANGED
@@ -10,50 +10,86 @@ model = WhisperForConditionalGeneration.from_pretrained(model_id)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)

-forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
+# Force Hebrew transcription
+forced_decoder_ids = processor.get_decoder_prompt_ids(
+    language="he",
+    task="transcribe"
+)
+
+stop_processing = False
+def stop():
+    global stop_processing
+    stop_processing = True

 def transcribe_audio(audio_file):
-    """Process only the first 30 seconds of an audio file and return text."""
+    """
+    Process the audio in 25-second chunks (capped at a 60-minute safety limit).
+    Return the full transcription as a single output.
+    """
+    global stop_processing
+    stop_processing = False
+
+    # Load at 16 kHz
     waveform, sr = librosa.load(audio_file, sr=16000)

-    # Limit to first 30 seconds
-    time_limit_s = 30
+    # Truncate to the first 6000 seconds (100 minutes)
+    time_limit_s = 6000
     if len(waveform) > sr * time_limit_s:
         waveform = waveform[: sr * time_limit_s]

-    # Preprocess
-    inputs = processor(
-        waveform,
-        sampling_rate=16000,
-        return_tensors="pt",
-        padding="longest",
-        return_attention_mask=True
-    )
-    input_features = inputs.input_features.to(device)
-    attention_mask = inputs.attention_mask.to(device)
-
-    # Transcribe
-    with torch.no_grad():
-        predicted_ids = model.generate(
-            input_features,
-            attention_mask=attention_mask,
-            max_new_tokens=444,
-            do_sample=False,
-            forced_decoder_ids=forced_decoder_ids
+    # Also limit if total is over 60 min (safety)
+    max_audio_sec = 60 * 60
+    if len(waveform) > sr * max_audio_sec:
+        waveform = waveform[: sr * max_audio_sec]
+
+    # Split into 25-second chunks
+    chunk_duration_s = 25
+    chunk_size = sr * chunk_duration_s
+    chunks = []
+    for start_idx in range(0, len(waveform), chunk_size):
+        chunk = waveform[start_idx : start_idx + chunk_size]
+        if len(chunk) < sr * 1:
+            continue
+        chunks.append(chunk)
+
+    partial_text = ""
+
+    # Transcribe chunk by chunk
+    for chunk in chunks:
+        if stop_processing:
+            return "⚠️ Stopped by User ⚠️"
+
+        inputs = processor(
+            chunk,
+            sampling_rate=16000,
+            return_tensors="pt",
+            padding="longest",
+            return_attention_mask=True
         )
+        input_features = inputs.input_features.to(device)
+        attention_mask = inputs.attention_mask.to(device)
+
+        with torch.no_grad():
+            predicted_ids = model.generate(
+                input_features,
+                attention_mask=attention_mask,
+                max_new_tokens=444,
+                do_sample=False,
+                forced_decoder_ids=forced_decoder_ids
+            )
+
+        text_chunk = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+        partial_text += text_chunk + "\n"

-    # Decode and return text
-    text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-    return text
+    return partial_text.strip()

-# Expose API endpoint for Make.com
+# Build Gradio UI with API support
 demo = gr.Interface(
     fn=transcribe_audio,
     inputs=gr.Audio(type="filepath"),
     outputs="text",
     title="Hebrew Whisper API",
-    api_name="transcribe"  # This enables API access
+    api_name="transcribe"  # Enables API access for Make.com
 )

-# Run on Hugging Face Spaces
 demo.launch()
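
Note: the commit defines a `stop()` helper that flips `stop_processing`, but the `gr.Interface` wrapper has no control that calls it. A minimal sketch of how it could be wired up with `gr.Blocks` is below; the component names (`audio_in`, `text_out`, `run_btn`, `stop_btn`) are illustrative assumptions, not part of the commit.

```python
# Hypothetical sketch: wiring stop() to a button with gr.Blocks.
# Assumes transcribe_audio and stop are already defined as in app.py.
import gradio as gr

with gr.Blocks(title="Hebrew Whisper API") as demo:
    audio_in = gr.Audio(type="filepath")
    text_out = gr.Textbox(label="Transcription")
    run_btn = gr.Button("Transcribe")
    stop_btn = gr.Button("Stop")

    # api_name keeps the endpoint reachable for Make.com, as in the Interface version
    run_btn.click(transcribe_audio, inputs=audio_in, outputs=text_out, api_name="transcribe")
    stop_btn.click(stop, inputs=None, outputs=None)

demo.launch()
```

For the flag to be picked up mid-run, the Space would also need the stop event to execute while a transcription is in progress (i.e. the queue must allow the two events to run concurrently); otherwise the stop click only takes effect after the current job finishes.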
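
Since the change exposes `api_name="transcribe"`, the endpoint can also be called programmatically. A minimal usage sketch with `gradio_client` follows, assuming a recent client version where `handle_file` is available; the Space id and audio filename are placeholders, not taken from the commit.

```python
# Hypothetical client call to the Space's /transcribe endpoint.
from gradio_client import Client, handle_file

client = Client("EladSpamson/hebrew-whisper")  # assumed Space id, replace with the real one
result = client.predict(
    handle_file("sample_hebrew.wav"),  # local audio file to upload
    api_name="/transcribe",
)
print(result)
```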