# Exceedea / app.py
# Last commit aefce6b "Update app.py" (verified) by EladSpamson; file size 2.59 kB.
import os

# Redirect every Hugging Face / XDG cache into /tmp so model and dataset
# downloads never try to write into a directory we lack permission for.
# These are assigned before transformers is imported further down, so the
# cache locations are already in effect when that library loads.
os.environ.update(
    {
        "HF_HOME": "/tmp/hf_cache",
        "TRANSFORMERS_CACHE": "/tmp/hf_cache",
        "HF_DATASETS_CACHE": "/tmp/hf_cache",
        "XDG_CACHE_HOME": "/tmp",
    }
)
import json
import tempfile

import librosa
import requests
import torch
from flask import Flask, Response, jsonify, request
from transformers import WhisperForConditionalGeneration, WhisperProcessor
app = Flask(__name__)

# Decide where inference runs before loading anything: the GPU when PyTorch
# can see one, the CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hebrew Whisper checkpoint from the Hugging Face Hub; replace the id to serve
# a different model. The processor bundles the feature extractor and tokenizer.
model_id = "ivrit-ai/whisper-large-v3-turbo"
processor = WhisperProcessor.from_pretrained(model_id)
# nn.Module.to() moves the weights in place and returns the same module.
model = WhisperForConditionalGeneration.from_pretrained(model_id).to(device)

# Decoder prompt pinning language=Hebrew and task=transcribe, so generation
# never falls back to Whisper's language auto-detection.
forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="he")
def transcribe_audio(
    audio_url: str,
    *,
    timeout: float = 60,
    chunk_sec: float = 25,
    max_sec: float = 3600,
) -> str:
    """Download the audio at *audio_url* and return its Hebrew transcription.

    The audio is decoded and resampled to 16 kHz, truncated to ``max_sec``
    seconds (1 hour by default) and transcribed in consecutive
    ``chunk_sec``-second windows. Each window's text becomes one line of the
    result.

    Args:
        audio_url: URL of the audio file to fetch.
        timeout: Seconds ``requests`` waits for the connection and between
            received bytes before giving up. This is not a cap on total
            download time.
        chunk_sec: Window length in seconds. The Whisper feature extractor
            truncates each window to its 30-second input size, so values
            above 30 silently drop audio.
        max_sec: Audio beyond this many seconds is ignored.

    Returns:
        The per-chunk transcriptions joined by newlines, with surrounding
        whitespace stripped. Returns ``""`` when the audio has no samples.

    Raises:
        ValueError: if ``chunk_sec`` is not positive.
        requests.RequestException: if the download fails, including
            ``requests.HTTPError`` for 4xx/5xx responses.
    """
    if chunk_sec <= 0:
        raise ValueError("chunk_sec must be positive")

    audio_path = _download_to_tempfile(audio_url, timeout)
    try:
        waveform, sr = librosa.load(audio_path, sr=16000)
    finally:
        # The file is only needed for decoding; never leave it behind in /tmp.
        os.remove(audio_path)

    waveform = waveform[: int(sr * max_sec)]
    chunk_size = int(sr * chunk_sec)
    # Collect parts and join once instead of repeated string concatenation.
    lines = [
        _transcribe_chunk(waveform[start : start + chunk_size], sr)
        for start in range(0, len(waveform), chunk_size)
    ]
    return "\n".join(lines).strip()


def _download_to_tempfile(audio_url, timeout):
    """Stream *audio_url* into a new uniquely named temp file and return its path.

    A fresh file per call, instead of one shared /tmp/temp_audio.wav, keeps
    concurrent requests from overwriting each other's audio. The caller owns
    the returned file and must delete it.
    """
    # NOTE(review): audio_url comes straight from the HTTP request body, so the
    # server will fetch any URL it can reach (SSRF). Restrict schemes/hosts if
    # this endpoint is exposed publicly.
    with requests.get(audio_url, stream=True, timeout=timeout) as response:
        # Fail here on 4xx/5xx; otherwise an HTML error page would be fed to librosa.
        response.raise_for_status()
        fd, path = tempfile.mkstemp(suffix=".wav")
        try:
            with os.fdopen(fd, "wb") as f:
                for block in response.iter_content(chunk_size=1 << 16):
                    f.write(block)
        except BaseException:
            os.remove(path)
            raise
    return path


def _transcribe_chunk(chunk, sr):
    """Return the forced-Hebrew transcription of one audio chunk."""
    # Keep the feature extractor's default padding ("max_length"): it pads every
    # chunk with silence to Whisper's 30 s / 3000-frame input. The previous
    # padding=True padded only to the chunk's own length (2500 frames for 25 s).
    # Depending on the transformers version, the encoder rejects that length or
    # generate() zero-pads it in feature space rather than with silence.
    inputs = processor(chunk, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(device)
    with torch.no_grad():
        predicted_ids = model.generate(
            input_features,
            forced_decoder_ids=forced_decoder_ids,
        )
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
@app.route("/transcribe", methods=["POST"])
def transcribe_endpoint():
    """Handle ``POST /transcribe``: transcribe the audio behind ``audio_url``.

    Request body: a JSON object such as ``{"audio_url": "https://..."}``.

    Responses:
        200: ``{"Transcription": "<text>"}``. Hebrew is emitted as raw UTF-8
            (``ensure_ascii=False``) rather than as escaped code points.
        400: the body is not a JSON object, or ``audio_url`` is missing,
            empty or not a string.
        502: fetching the audio failed (connection error, timeout, invalid
            URL, or an HTTP error status raised during the download).
    """
    # silent=True: malformed JSON or a non-JSON content type yields None rather
    # than Flask's HTML error page, so every client error gets a JSON 400.
    data = request.get_json(silent=True)
    # A valid JSON body can still be null, a list or a scalar; only objects have keys.
    audio_url = data.get("audio_url") if isinstance(data, dict) else None
    if not audio_url:
        return jsonify({"error": "Missing 'audio_url' in request"}), 400
    if not isinstance(audio_url, str):
        return jsonify({"error": "'audio_url' must be a string"}), 400

    try:
        text = transcribe_audio(audio_url)
    except requests.RequestException as exc:
        # The failure is on the remote audio source, not in this service.
        app.logger.warning("Could not download %s: %s", audio_url, exc)
        return jsonify({"error": f"Could not download audio: {exc}"}), 502

    # Return Hebrew characters directly instead of \u-escaped JSON.
    payload = {"Transcription": text}
    return Response(
        json.dumps(payload, ensure_ascii=False),
        status=200,
        mimetype="application/json; charset=utf-8",
    )
if __name__ == "__main__":
    # Script entry point: start Flask's built-in server on every interface,
    # port 7860 (presumably the port the hosting container exposes; confirm).
    app.run(port=7860, host="0.0.0.0")