EladSpamson committed on
Commit
d2be6df
·
verified ·
1 Parent(s): aba5a96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -1,12 +1,13 @@
 
1
  import requests
2
  import torch
3
  import librosa
4
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
5
- from flask import Flask, request, jsonify
6
 
7
  app = Flask(__name__)
8
 
9
- model_id = "openai/whisper-large-v3"
 
10
  processor = WhisperProcessor.from_pretrained(model_id)
11
  model = WhisperForConditionalGeneration.from_pretrained(model_id)
12
 
@@ -33,27 +34,20 @@ def transcribe_audio(audio_url):
33
  inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
34
  input_features = inputs.input_features.to(device)
35
 
36
- with torch.no_grad():
37
- predicted_ids = model.generate(
38
- input_features,
39
- forced_decoder_ids=forced_decoder_ids
40
- )
41
-
42
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
43
  partial_text += transcription + "\n"
44
 
45
- return partial_text.strip()
46
 
47
  @app.route('/transcribe', methods=['POST'])
48
  def transcribe_endpoint():
49
  data = request.get_json()
50
- audio_url = data.get('audio_url')
51
- if not audio_url:
52
- return jsonify({"error": "Missing 'audio_url' in request"}), 400
53
 
54
  transcription = transcribe_audio(audio_url)
55
 
56
- return jsonify({"transcription": transcription})
57
 
58
  if __name__ == '__main__':
59
  app.run(host="0.0.0.0", port=8080)
 
1
+ from flask import Flask, request, jsonify
2
  import requests
3
  import torch
4
  import librosa
5
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
6
 
7
  app = Flask(__name__)
8
 
9
+ # Temporarily using smaller model for faster testing
10
+ model_id = "openai/whisper-base"
11
  processor = WhisperProcessor.from_pretrained(model_id)
12
  model = WhisperForConditionalGeneration.from_pretrained(model_id)
13
 
 
34
  inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
35
  input_features = inputs.input_features.to(device)
36
 
37
+ predicted_ids = model.generate(input_features)
 
 
 
 
 
38
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
39
  partial_text += transcription + "\n"
40
 
41
+ return partial_text
42
 
43
  @app.route('/transcribe', methods=['POST'])
44
  def transcribe_endpoint():
45
  data = request.get_json()
46
+ audio_url = data['audio_url']
 
 
47
 
48
  transcription = transcribe_audio(audio_url)
49
 
50
+ return {"transcription": transcription}
51
 
52
  if __name__ == '__main__':
53
  app.run(host="0.0.0.0", port=8080)