EladSpamson committed on
Commit
4cca673
·
verified ·
1 Parent(s): adc3da1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -46
app.py CHANGED
@@ -10,53 +10,75 @@ model = WhisperForConditionalGeneration.from_pretrained(model_id)
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  model.to(device)
12
 
13
- forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
 
 
 
 
 
 
 
 
 
14
 
15
- def transcribe_long(audio_file):
16
- # 1) Load full audio (limit to 60 minutes)
 
 
 
 
 
 
17
  waveform, sr = librosa.load(audio_file, sr=16000)
18
- if len(waveform) > sr * 3600:
19
- waveform = waveform[: sr * 3600]
20
-
21
- # 2) Split into ~2min chunks
22
- chunk_sec = 120
23
- chunk_size = sr * chunk_sec
24
- all_text = []
25
- for start in range(0, len(waveform), chunk_size):
26
- chunk = waveform[start : start + chunk_size]
27
- # skip chunks <2s if you want
28
- if len(chunk) < sr * 2:
29
- continue
30
-
31
- # 3) Encode with attention mask
32
- inputs = processor(
33
- chunk,
34
- sampling_rate=16000,
35
- return_tensors="pt",
36
- padding="longest",
37
- return_attention_mask=True
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  )
39
- input_features = inputs.input_features.to(device)
40
- attention_mask = inputs.attention_mask.to(device)
41
-
42
- # 4) Generate
43
- with torch.no_grad():
44
- predicted_ids = model.generate(
45
- input_features,
46
- attention_mask=attention_mask,
47
- max_new_tokens=444,
48
- do_sample=False,
49
- forced_decoder_ids=forced_decoder_ids
50
- )
51
- text_chunk = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
52
- all_text.append(text_chunk)
53
-
54
- return " ".join(all_text)
55
-
56
- demo = gr.Interface(
57
- fn=transcribe_long,
58
- inputs=gr.Audio(type="filepath", label="Upload Audio (unlimited)"),
59
- outputs="text",
60
- title="Chunked Whisper (No Token Overflow)"
61
- )
62
  demo.launch()
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  model.to(device)
12
 
13
+ # Force Hebrew (transcribe) decoding:
14
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
15
+ language="he",
16
+ task="transcribe"
17
+ )
18
+
19
+ stop_processing = False
20
+ def stop():
21
+ global stop_processing
22
+ stop_processing = True
23
 
24
+ def transcribe_first_chunk(audio_file):
25
+ """
26
+ Transcribe only the first 'time_limit_s' seconds of the uploaded audio.
27
+ """
28
+ global stop_processing
29
+ stop_processing = False
30
+
31
+ # A) Load at 16kHz
32
  waveform, sr = librosa.load(audio_file, sr=16000)
33
+
34
+ # B) Truncate to the first 4 minutes
35
+ time_limit_s = 4 * 60 # 4 minutes = 240 seconds
36
+ if len(waveform) > sr * time_limit_s:
37
+ waveform = waveform[: sr * time_limit_s]
38
+
39
+ # Also limit if total is over 60 min (safety)
40
+ max_audio_sec = 60 * 60
41
+ if len(waveform) > sr * max_audio_sec:
42
+ waveform = waveform[: sr * max_audio_sec]
43
+
44
+ # C) Preprocess: get attention mask
45
+ inputs = processor(
46
+ waveform,
47
+ sampling_rate=16000,
48
+ return_tensors="pt",
49
+ padding="longest",
50
+ return_attention_mask=True
51
+ )
52
+ input_features = inputs.input_features.to(device)
53
+ attention_mask = inputs.attention_mask.to(device)
54
+
55
+ if stop_processing:
56
+ return "⚠️ Stopped by User ⚠️"
57
+
58
+ # D) Generate
59
+ with torch.no_grad():
60
+ predicted_ids = model.generate(
61
+ input_features,
62
+ attention_mask=attention_mask,
63
+ max_new_tokens=444, # keep total under 448 tokens
64
+ do_sample=False, # deterministic
65
+ forced_decoder_ids=forced_decoder_ids # ensure Hebrew transcription
66
  )
67
+
68
+ # E) Decode
69
+ text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
70
+ return text
71
+
72
+ with gr.Blocks() as demo:
73
+ gr.Markdown("## Hebrew Whisper (Only First 4 Minutes)")
74
+
75
+ audio_input = gr.Audio(type="filepath", label="Upload Audio (Truncate to 4min)")
76
+ output_text = gr.Textbox(label="Partial Transcription")
77
+
78
+ start_btn = gr.Button("Start Transcription")
79
+ stop_btn = gr.Button("Stop Processing", variant="stop")
80
+
81
+ start_btn.click(transcribe_first_chunk, inputs=audio_input, outputs=output_text)
82
+ stop_btn.click(stop)
83
+
 
 
 
 
 
 
84
  demo.launch()