Spaces:
Runtime error
Runtime error
File size: 2,267 Bytes
d2be6df 886af50 e6cfbd7 84b9b91 221d07a 3d55353 aa43ea6 84b9b91 886af50 aa43ea6 10c5d51 221d07a 0348b75 f34f2ca ba4bba3 aa43ea6 040da24 aa43ea6 040da24 aa43ea6 040da24 aa43ea6 040da24 aa43ea6 040da24 aa43ea6 040da24 aa43ea6 040da24 aa43ea6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import os
import tempfile

import librosa
import requests
import torch
from flask import Flask, request, jsonify
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Writable cache directory for downloaded model files.
# NOTE: `transformers` is already imported at this point, so setting HF_HOME
# here is too late to move the library's default cache location. The variable
# is still exported for anything that reads it later, but the directory is
# also passed explicitly via `cache_dir` below so the downloads really land
# in a writable place.
HF_CACHE_DIR = '/tmp/hf_cache'
os.environ['HF_HOME'] = HF_CACHE_DIR

app = Flask(__name__)

# Temporarily using smaller model for faster testing
model_id = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
model = WhisperForConditionalGeneration.from_pretrained(
    model_id,
    cache_dir=HF_CACHE_DIR,
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Decoder prompt that pins generation to Hebrew transcription.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")
def transcribe_audio(audio_url):
    """Download the audio at ``audio_url`` and return its transcription.

    The audio is resampled to 16 kHz, truncated to one hour, split into
    25-second chunks and each chunk is decoded with the module-level Whisper
    ``model``/``processor`` using ``forced_decoder_ids``.

    Args:
        audio_url: URL of the audio file to fetch.

    Returns:
        The chunk transcriptions joined by newlines, with surrounding
        whitespace stripped. Empty string if the audio has no samples.

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status.
    """
    sample_rate = 16000
    max_duration_sec = 3600  # Safety limit (1 hour)
    chunk_duration_sec = 25

    # NOTE(review): audio_url is caller-supplied and fetched server-side;
    # consider restricting allowed schemes/hosts if this is exposed publicly.
    # A timeout keeps a stalled remote server from hanging the worker forever.
    response = requests.get(audio_url, timeout=60)
    # Fail early on 4xx/5xx instead of handing an error page to the decoder.
    response.raise_for_status()

    # Use a unique temp file so concurrent requests cannot overwrite each
    # other's audio, and always remove it once the samples are in memory.
    fd, audio_path = tempfile.mkstemp(suffix=".wav", dir="/tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(response.content)
        # Load audio
        waveform, sr = librosa.load(audio_path, sr=sample_rate)
    finally:
        os.remove(audio_path)

    waveform = waveform[:sr * max_duration_sec]

    # Split into smaller chunks
    chunk_size = sr * chunk_duration_sec
    chunks = [waveform[i : i + chunk_size] for i in range(0, len(waveform), chunk_size)]

    transcriptions = []
    for chunk in chunks:
        # Rely on the feature extractor's default padding, which pads every
        # chunk to Whisper's fixed 30-second window; `padding=True` would pad
        # only to the longest item and yield features the encoder rejects.
        inputs = processor(
            chunk,
            sampling_rate=sr,
            return_tensors="pt",
        )
        input_features = inputs.input_features.to(device)

        # Generate text
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids,
            )

        transcription = processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True,
        )[0]
        transcriptions.append(transcription)

    return "\n".join(transcriptions).strip()
@app.route('/transcribe', methods=['POST'])
def transcribe_endpoint():
    """Handle ``POST /transcribe``.

    Expects a JSON object body with an ``audio_url`` key.

    Returns:
        200 with ``{"transcription": ...}`` on success;
        400 with ``{"error": ...}`` when the body is not a JSON object or
        ``audio_url`` is missing/empty;
        502 with ``{"error": ...}`` when the audio could not be downloaded.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # aborting, so every malformed request gets the same 400 response below.
    data = request.get_json(silent=True)
    # The body may legally be a JSON list/string/number; only objects have keys.
    audio_url = data.get('audio_url') if isinstance(data, dict) else None
    if not audio_url:
        return jsonify({"error": "Missing 'audio_url' in request"}), 400

    try:
        transcription = transcribe_audio(audio_url)
    except requests.RequestException as exc:
        # The upstream fetch failed; report it rather than a bare 500.
        return jsonify({"error": f"Could not download audio: {exc}"}), 502

    return jsonify({"transcription": transcription})
# Start Flask's built-in server only when this file is executed directly;
# it listens on all interfaces, port 7860.
if __name__ == "__main__":
    app.run(port=7860, host="0.0.0.0")
|