EladSpamson committed
Commit fd4a773 · verified · 1 Parent(s): 605203a

Update app.py

Files changed (1)
  1. app.py +63 -39
app.py CHANGED
@@ -4,85 +4,109 @@ import librosa
 import numpy as np
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
-# Load the Faster Whisper model
-model_id = "ivrit-ai/faster-whisper-v2-d4"  # Switch to a smaller, faster model
-processor = WhisperProcessor.from_pretrained(model_id)
-model = WhisperForConditionalGeneration.from_pretrained(model_id)
-
-# Force GPU usage (if available)
+# ------------------------------
+# 1. Load the Model & Processor
+# ------------------------------
+model_id = "ivrit-ai/faster-whisper-v2-d4"  # Replace with a verified HF model if needed, e.g. "openai/whisper-large-v2"
+
+try:
+    processor = WhisperProcessor.from_pretrained(model_id)
+    model = WhisperForConditionalGeneration.from_pretrained(model_id)
+except OSError as e:
+    raise ValueError(
+        f"Unable to load model or tokenizer from '{model_id}'. "
+        "Double-check that the model ID is valid on Hugging Face Hub."
+    ) from e
+
+# Force GPU usage if available
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
-# Global variable to control stopping
-stop_processing = False
+# ---------------------------
+# 2. Global Stop Flag
+# ---------------------------
+stop_processing = False
 
-# Function to stop transcription
 def stop():
+    """
+    Callback to set a global stop flag, allowing the user to interrupt
+    transcription mid-way through processing.
+    """
     global stop_processing
-    stop_processing = True  # This will break transcription
+    stop_processing = True
+
 
-# Function to process long audio in chunks
+# -------------------------------------------
+# 3. Transcription Function (with Chunking)
+# -------------------------------------------
 def transcribe(audio):
+    """
+    Transcribes Hebrew speech from an uploaded audio file.
+    Splits long audio into 2-minute chunks to handle large files (up to 60 min).
+    """
     global stop_processing
-    stop_processing = False  # Reset stop flag when new transcription starts
 
-    # Load the audio file and convert to 16kHz
+    stop_processing = False  # Reset at start
+
+    # --- A) Load Audio & Limit to 60 Minutes
     waveform, sr = librosa.load(audio, sr=16000)
-
-    # Set chunk size (~2 min per chunk)
-    chunk_duration = 2 * 60  # 2 minutes (120 seconds)
-    max_audio_length = 60 * 60  # 60 minutes
-    chunks = []
-
-    # Ensure audio doesn't exceed 60 minutes
+    max_audio_length = 60 * 60  # 60 minutes in seconds
     if len(waveform) > sr * max_audio_length:
         waveform = waveform[: sr * max_audio_length]
 
-    # Split audio into ~2-minute chunks
+    # --- B) Split Audio into ~2-minute Chunks
+    chunk_duration = 2 * 60  # 2 minutes (120 seconds)
+    chunks = []
     for i in range(0, len(waveform), sr * chunk_duration):
-        if stop_processing:
+        if stop_processing:
             return "⚠️ Transcription Stopped by User ⚠️"
 
         chunk = waveform[i : i + sr * chunk_duration]
-        if len(chunk) < sr * 2:  # Skip chunks shorter than 2 seconds
+        # Optional: skip very short chunks (<2 seconds)
+        if len(chunk) < sr * 2:
             continue
         chunks.append(chunk)
 
-    # Process each chunk and transcribe
+    # --- C) Process Each Chunk with Whisper
    transcriptions = []
     for chunk in chunks:
-        if stop_processing:
+        if stop_processing:
             return "⚠️ Transcription Stopped by User ⚠️"
 
-        input_features = processor(chunk, sampling_rate=16000, return_tensors="pt", language="he").input_features.to(device)
+        # Convert the chunk to Whisper input features
+        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", language="he").input_features.to(device)
 
         with torch.no_grad():
             predicted_ids = model.generate(
-                input_features,
-                max_new_tokens=444,  # FIXED: Prevents exceeding model limit
-                do_sample=False  # Ensures stable, faster transcription
+                inputs,
+                max_new_tokens=444,  # Prevent exceeding model's token limit
+                do_sample=False,  # Stable transcription (disable random sampling)
             )
 
-        # Decode and store transcription
-        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-        transcriptions.append(transcription)
+        # Decode tokens to text
+        text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+        transcriptions.append(text)
+
+    # --- D) Combine All Chunk Transcriptions
+    return " ".join(transcriptions)
 
-    # Join all chunk transcriptions into one
-    full_transcription = " ".join(transcriptions)
-    return full_transcription
 
-# Create the Gradio Interface
+# ------------------------
+# 4. Build Gradio Interface
+# ------------------------
 with gr.Blocks() as iface:
-    gr.Markdown("# Hebrew Speech-to-Text (Faster Whisper)")
+    gr.Markdown("## Hebrew Speech-to-Text (Faster Whisper)")
 
+    # Inputs/Outputs
     audio_input = gr.Audio(type="filepath", label="Upload Hebrew Audio")
     output_text = gr.Textbox(label="Transcription Output")
 
+    # Buttons
     start_btn = gr.Button("Start Transcription")
     stop_btn = gr.Button("Stop Processing", variant="stop")
 
+    # Click Actions
     start_btn.click(transcribe, inputs=audio_input, outputs=output_text)
-    stop_btn.click(stop)  # Calls the stop function when clicked
+    stop_btn.click(stop)
 
-# Launch the Gradio app
+# Launch the Gradio App
 iface.launch()
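
Reviewer note: the fixed-size chunking in transcribe() can be sanity-checked without downloading the model. A minimal numpy-only sketch (the 5-minute test signal and expected output are illustrative assumptions, not part of this commit):

    # Standalone check of the chunking arithmetic used in transcribe().
    # The model and processor are not needed for this.
    import numpy as np

    sr = 16000
    chunk_duration = 2 * 60        # 120 s per chunk, as in the commit
    max_audio_length = 60 * 60     # 60-minute cap

    waveform = np.zeros(sr * 300)  # hypothetical 5-minute signal
    waveform = waveform[: sr * max_audio_length]

    chunks = [
        waveform[i : i + sr * chunk_duration]
        for i in range(0, len(waveform), sr * chunk_duration)
        if len(waveform[i : i + sr * chunk_duration]) >= sr * 2
    ]
    print([len(c) / sr for c in chunks])  # -> [120.0, 120.0, 60.0]

One boundary effect worth flagging: a hard cut every 120 s can split a word across chunks, so the joined transcription may glitch at chunk edges. Overlapping chunks or silence-based splitting would avoid that; neither is part of this commit.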