EladSpamson commited on
Commit
bceee03
·
verified ·
1 Parent(s): d2be6df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -34,20 +34,24 @@ def transcribe_audio(audio_url):
34
  inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
35
  input_features = inputs.input_features.to(device)
36
 
37
- predicted_ids = model.generate(input_features)
 
 
38
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
39
  partial_text += transcription + "\n"
40
 
41
- return partial_text
42
 
43
  @app.route('/transcribe', methods=['POST'])
44
  def transcribe_endpoint():
45
  data = request.get_json()
46
- audio_url = data['audio_url']
 
 
47
 
48
  transcription = transcribe_audio(audio_url)
49
 
50
- return {"transcription": transcription}
51
 
52
  if __name__ == '__main__':
53
- app.run(host="0.0.0.0", port=8080)
 
34
  inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
35
  input_features = inputs.input_features.to(device)
36
 
37
+ with torch.no_grad():
38
+ predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
39
+
40
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
41
  partial_text += transcription + "\n"
42
 
43
+ return partial_text.strip()
44
 
45
  @app.route('/transcribe', methods=['POST'])
46
  def transcribe_endpoint():
47
  data = request.get_json()
48
+ audio_url = data.get('audio_url')
49
+ if not audio_url:
50
+ return jsonify({"error": "Missing 'audio_url' in request"}), 400
51
 
52
  transcription = transcribe_audio(audio_url)
53
 
54
+ return jsonify({"transcription": transcription})
55
 
56
  if __name__ == '__main__':
57
+ app.run(host="0.0.0.0", port=7860)