File size: 2,429 Bytes
fdb39f4
 
 
 
 
 
d2be6df
886af50
e6cfbd7
 
221d07a
3d55353
886af50
 
fdb39f4
10c5d51
221d07a
 
 
 
 
0348b75
f34f2ca
ba4bba3
aa43ea6
fdb39f4
aa43ea6
 
 
 
 
fdb39f4
aa43ea6
040da24
fdb39f4
aa43ea6
 
 
fdb39f4
aa43ea6
 
040da24
aa43ea6
 
 
fdb39f4
 
aa43ea6
 
fdb39f4
aa43ea6
040da24
 
 
 
 
fdb39f4
 
aa43ea6
 
 
 
 
 
 
 
 
 
 
fdb39f4
aa43ea6
 
 
 
fdb39f4
aa43ea6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os

# Ensure environment variables are set before loading transformers
# (the cache location is read at import time, so these assignments must
# come before the `transformers` import below).
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME — kept here for older versions; confirm
# which transformers version this deploys against.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"

from flask import Flask, request, jsonify
import requests
import torch
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration

app = Flask(__name__)

# Use a smaller model for faster CPU loading
# (whisper-base trades some accuracy for startup time and memory).
model_id = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)

# Prefer GPU when available; otherwise fall back to CPU inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Force Hebrew ("he") transcription output regardless of detected language.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")

def transcribe_audio(audio_url):
    """Download an audio file and transcribe it in 25-second chunks.

    Args:
        audio_url: HTTP(S) URL of an audio file librosa can decode.

    Returns:
        str: transcriptions of all chunks joined by newlines, with
        surrounding whitespace stripped.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
        requests.RequestException: on network errors or timeout.
    """
    import tempfile

    # 1) Download the file. A timeout keeps a stalled connection from
    #    hanging the worker forever, and raise_for_status() surfaces HTTP
    #    errors instead of silently "transcribing" an error page.
    response = requests.get(audio_url, timeout=60)
    response.raise_for_status()

    # Use a unique temp file (instead of a fixed /tmp path) so concurrent
    # requests under Flask's threaded server don't clobber each other.
    fd, audio_path = tempfile.mkstemp(suffix=".wav", dir="/tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(response.content)

        # 2) Load audio with librosa, resampled to Whisper's expected 16 kHz.
        waveform, sr = librosa.load(audio_path, sr=16000)

        # 3) Truncate to 1 hour max to bound processing time and memory.
        max_duration_sec = 3600
        waveform = waveform[: sr * max_duration_sec]

        # 4) Split into 25-second chunks (Whisper natively handles up to
        #    30 s; 25 s leaves headroom for padding).
        chunk_duration_sec = 25
        chunk_size = sr * chunk_duration_sec
        chunks = [waveform[i : i + chunk_size] for i in range(0, len(waveform), chunk_size)]

        texts = []
        for chunk in chunks:
            # Preprocess chunk into model input features.
            inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
            input_features = inputs.input_features.to(device)

            # Transcribe chunk — inference only, so skip gradient tracking.
            with torch.no_grad():
                predicted_ids = model.generate(
                    input_features,
                    forced_decoder_ids=forced_decoder_ids,
                )

            # Convert token IDs back to text.
            texts.append(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
    finally:
        # Always remove the temp file, even when decoding/transcription
        # fails, so /tmp doesn't accumulate downloads.
        try:
            os.remove(audio_path)
        except OSError:
            pass

    # join + strip matches the original "one chunk per line" output exactly.
    return "\n".join(texts).strip()

@app.route('/transcribe', methods=['POST'])
def transcribe_endpoint():
    """POST /transcribe with JSON {"audio_url": "<url>"}.

    Returns:
        200 with {"transcription": "..."} on success,
        400 with {"error": "..."} on missing/invalid input or a failed download.
    """
    # silent=True makes get_json() return None (instead of raising an HTML
    # 400 page) when the body is missing or not valid JSON; `or {}` keeps
    # the .get() below safe in that case.
    data = request.get_json(silent=True) or {}
    audio_url = data.get('audio_url')
    if not audio_url:
        return jsonify({"error": "Missing 'audio_url' in request"}), 400

    # Perform transcription; report download/network failures as a client
    # error rather than letting them surface as an unhandled 500.
    try:
        transcription = transcribe_audio(audio_url)
    except requests.RequestException as e:
        return jsonify({"error": f"Could not fetch audio: {e}"}), 400

    return jsonify({"transcription": transcription})

if __name__ == '__main__':
    # Bind to all interfaces. Port defaults to 7860 (Hugging Face Spaces
    # convention) but can be overridden via the PORT environment variable.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))