Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,13 @@
|
|
|
|
1 |
import requests
|
2 |
import torch
|
3 |
import librosa
|
4 |
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
5 |
-
from flask import Flask, request, jsonify
|
6 |
|
7 |
app = Flask(__name__)
|
8 |
|
9 |
-
|
|
|
10 |
processor = WhisperProcessor.from_pretrained(model_id)
|
11 |
model = WhisperForConditionalGeneration.from_pretrained(model_id)
|
12 |
|
@@ -33,27 +34,20 @@ def transcribe_audio(audio_url):
|
|
33 |
inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
|
34 |
input_features = inputs.input_features.to(device)
|
35 |
|
36 |
-
|
37 |
-
predicted_ids = model.generate(
|
38 |
-
input_features,
|
39 |
-
forced_decoder_ids=forced_decoder_ids
|
40 |
-
)
|
41 |
-
|
42 |
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
43 |
partial_text += transcription + "\n"
|
44 |
|
45 |
-
return partial_text
|
46 |
|
47 |
@app.route('/transcribe', methods=['POST'])
|
48 |
def transcribe_endpoint():
|
49 |
data = request.get_json()
|
50 |
-
audio_url = data
|
51 |
-
if not audio_url:
|
52 |
-
return jsonify({"error": "Missing 'audio_url' in request"}), 400
|
53 |
|
54 |
transcription = transcribe_audio(audio_url)
|
55 |
|
56 |
-
return
|
57 |
|
58 |
if __name__ == '__main__':
|
59 |
app.run(host="0.0.0.0", port=8080)
|
|
|
1 |
+
from flask import Flask, request, jsonify
|
2 |
import requests
|
3 |
import torch
|
4 |
import librosa
|
5 |
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
|
|
6 |
|
7 |
app = Flask(__name__)
|
8 |
|
9 |
+
# Temporarily using smaller model for faster testing
|
10 |
+
model_id = "openai/whisper-base"
|
11 |
processor = WhisperProcessor.from_pretrained(model_id)
|
12 |
model = WhisperForConditionalGeneration.from_pretrained(model_id)
|
13 |
|
|
|
34 |
inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
|
35 |
input_features = inputs.input_features.to(device)
|
36 |
|
37 |
+
predicted_ids = model.generate(input_features)
|
|
|
|
|
|
|
|
|
|
|
38 |
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
39 |
partial_text += transcription + "\n"
|
40 |
|
41 |
+
return partial_text
|
42 |
|
43 |
@app.route('/transcribe', methods=['POST'])
|
44 |
def transcribe_endpoint():
|
45 |
data = request.get_json()
|
46 |
+
audio_url = data['audio_url']
|
|
|
|
|
47 |
|
48 |
transcription = transcribe_audio(audio_url)
|
49 |
|
50 |
+
return {"transcription": transcription}
|
51 |
|
52 |
if __name__ == '__main__':
|
53 |
app.run(host="0.0.0.0", port=8080)
|