File size: 2,590 Bytes
fdb39f4
 
aefce6b
fdb39f4
 
a6aa54b
 
fdb39f4
67a7670
 
886af50
e6cfbd7
 
221d07a
3d55353
886af50
 
aefce6b
5b89128
221d07a
 
 
 
 
0348b75
aefce6b
66d0ca2
 
aa43ea6
aefce6b
aa43ea6
 
 
 
 
aefce6b
aa43ea6
040da24
aefce6b
 
aa43ea6
5b89128
66d0ca2
 
040da24
aa43ea6
 
 
66d0ca2
aa43ea6
 
aefce6b
aa43ea6
66d0ca2
 
 
 
040da24
fdb39f4
aa43ea6
 
 
 
66d0ca2
aa43ea6
 
66d0ca2
aa43ea6
 
 
66d0ca2
67a7670
aefce6b
67a7670
 
 
 
 
 
aa43ea6
66d0ca2
aa43ea6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os

# Environment variables to avoid permission issues.
# NOTE: these must be set BEFORE the transformers/librosa imports below —
# the libraries read them at import time, and the defaults (under $HOME)
# are often unwritable in containerized deployments.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
os.environ["XDG_CACHE_HOME"] = "/tmp"

from flask import Flask, request, jsonify, Response
import json
import requests
import torch
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration

app = Flask(__name__)

# Use your custom Hebrew Whisper model (example: ivrit-ai/whisper-large-v3-turbo).
# Loaded once at import time; the first run downloads weights into /tmp/hf_cache.
model_id = "ivrit-ai/whisper-large-v3-turbo"
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)

# Prefer GPU when available; otherwise inference runs on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Force Hebrew to skip auto-detect: these prompt ids are passed to generate()
# so decoding is pinned to language="he", task="transcribe".
forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")

def transcribe_audio(audio_url):
    """Download an audio file and transcribe it to Hebrew text.

    Args:
        audio_url: HTTP(S) URL of the audio file to fetch.

    Returns:
        The full transcription as a single string, with one line per
        25-second chunk.

    Raises:
        requests.HTTPError: if the download returns an error status.
        requests.RequestException: on network failure or timeout.
    """
    # 1) Download audio file. Fail fast on HTTP errors instead of feeding an
    #    error page to librosa, and bound the wait with a timeout so a stalled
    #    server cannot hang the request forever.
    response = requests.get(audio_url, timeout=60)
    response.raise_for_status()
    audio_path = "/tmp/temp_audio.wav"
    with open(audio_path, "wb") as f:
        f.write(response.content)

    try:
        # 2) Load audio with librosa, resampled to 16 kHz mono as Whisper expects.
        waveform, sr = librosa.load(audio_path, sr=16000)

        # 3) Limit to 1 hour to bound memory use and runtime.
        waveform = waveform[: sr * 3600]

        # 4) Split into 25-second chunks (kept under Whisper's 30 s window).
        chunk_sec = 25
        chunk_size = sr * chunk_sec
        chunks = [waveform[i : i + chunk_size] for i in range(0, len(waveform), chunk_size)]

        # Collect per-chunk texts and join once at the end (avoids quadratic
        # string concatenation for long recordings).
        texts = []
        for chunk in chunks:
            inputs = processor(chunk, sampling_rate=sr, return_tensors="pt", padding=True)
            input_features = inputs.input_features.to(device)

            # Generate forced-Hebrew transcription; no gradients needed.
            with torch.no_grad():
                predicted_ids = model.generate(
                    input_features,
                    forced_decoder_ids=forced_decoder_ids
                )

            texts.append(
                processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            )
    finally:
        # Remove the temp file even if decoding fails, so repeated requests
        # don't leave stale data in /tmp.
        try:
            os.remove(audio_path)
        except OSError:
            pass

    return "\n".join(texts).strip()

@app.route("/transcribe", methods=["POST"])
def transcribe_endpoint():
    """POST {"audio_url": <url>} -> {"Transcription": <text>} as UTF-8 JSON.

    Returns 400 on a missing/invalid body or an unreachable audio URL.
    """
    # silent=True makes get_json() return None (instead of raising / as it
    # would otherwise leave data=None and crash on .get) for a missing or
    # malformed JSON body; normalize to an empty dict.
    data = request.get_json(silent=True) or {}
    audio_url = data.get("audio_url")
    if not audio_url:
        return jsonify({"error": "Missing 'audio_url' in request"}), 400

    try:
        text = transcribe_audio(audio_url)
    except requests.RequestException as e:
        # A bad or unreachable URL is a client error, not a server crash.
        return jsonify({"error": f"Failed to download audio: {e}"}), 400

    # Serialize manually with ensure_ascii=False so Hebrew characters are
    # returned as-is rather than as \uXXXX escapes.
    payload = {"Transcription": text}
    return Response(
        json.dumps(payload, ensure_ascii=False),
        status=200,
        mimetype="application/json; charset=utf-8"
    )

if __name__ == "__main__":
    # Dev-server entry point: listen on all interfaces on port 7860
    # (the conventional port for Hugging Face Spaces).
    app.run(host="0.0.0.0", port=7860)