File size: 2,278 Bytes
fdb39f4
 
1c7c059
fdb39f4
 
a6aa54b
 
fdb39f4
d2be6df
886af50
e6cfbd7
 
221d07a
3d55353
886af50
 
1c7c059
10c5d51
221d07a
 
 
 
 
0348b75
aa43ea6
a6aa54b
aa43ea6
 
 
 
 
1c7c059
aa43ea6
040da24
1c7c059
aa43ea6
 
 
1c7c059
aa43ea6
 
040da24
aa43ea6
 
 
fdb39f4
aa43ea6
 
1c7c059
aa43ea6
1c7c059
040da24
fdb39f4
aa43ea6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os

# Ensure environment variables are set before Transformers are imported
# (cache locations are read when the library loads). /tmp is used because
# it is typically the only writable path in containerized deployments.
os.environ["HF_HOME"] = "/tmp/hf_cache"
# NOTE(review): TRANSFORMERS_CACHE is a legacy alias of HF_HOME in newer
# transformers releases — kept for compatibility; confirm it is still needed.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
os.environ["XDG_CACHE_HOME"] = "/tmp"

import tempfile

from flask import Flask, request, jsonify
import requests
import torch
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration

app = Flask(__name__)

# Using a smaller model for faster CPU loading
model_id = "openai/whisper-base"
# Processor bundles the feature extractor (audio -> log-mel) and tokenizer.
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)

# Prefer GPU when available; otherwise run inference on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

def transcribe_audio(audio_url):
    """Download an audio file, transcribe it with Whisper in chunks, return the text.

    Args:
        audio_url: HTTP(S) URL of an audio file in any format librosa can decode.

    Returns:
        Concatenated transcription of all chunks (newline-separated), stripped
        of leading/trailing whitespace. Empty string for empty audio.

    Raises:
        requests.HTTPError: if the download returns a non-2xx status.
        requests.RequestException: on network failure or timeout.
    """
    # 1) Download the audio. A timeout prevents a dead server from hanging the
    #    worker forever, and raise_for_status() fails fast on 404/403 instead
    #    of handing an HTML error page to librosa.
    response = requests.get(audio_url, timeout=60)
    response.raise_for_status()

    # Unique temp file per call: the previous fixed /tmp path raced when two
    # requests were transcribed concurrently. delete=False so we can reopen it
    # by name; cleanup happens in the finally block.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False, dir="/tmp")
    try:
        tmp.write(response.content)
        tmp.close()

        # 2) Decode and resample to 16 kHz mono — the rate the processor expects.
        waveform, sr = librosa.load(tmp.name, sr=16000)
    finally:
        tmp.close()
        os.remove(tmp.name)

    # 3) Optional safety limit (1 hour) to bound memory and inference time.
    max_duration_sec = 3600
    waveform = waveform[: sr * max_duration_sec]

    # 4) Split into smaller chunks (25 s) — under Whisper's ~30 s receptive
    #    window, so nothing inside a chunk gets truncated by the model.
    chunk_duration_sec = 25
    chunk_size = sr * chunk_duration_sec
    chunks = [waveform[i : i + chunk_size] for i in range(0, len(waveform), chunk_size)]

    pieces = []
    for chunk in chunks:
        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
        input_features = inputs.input_features.to(device)

        # **No** forced_decoder_ids => Whisper auto-detects language
        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        pieces.append(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])

    # join + strip yields the same text as the original "+= ... + '\n'" loop.
    return "\n".join(pieces).strip()

@app.route('/transcribe', methods=['POST'])
def transcribe_endpoint():
    """POST /transcribe with JSON {"audio_url": ...} -> {"transcription": ...}.

    Returns 400 for a missing/invalid body or an unfetchable URL; 200 with the
    transcription on success.
    """
    # silent=True makes get_json() return None on a malformed body or wrong
    # Content-Type instead of raising — previously that crashed with a 500
    # when data was None and .get was called on it.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "Missing 'audio_url' in request"}), 400

    audio_url = data.get('audio_url')
    if not audio_url:
        return jsonify({"error": "Missing 'audio_url' in request"}), 400

    try:
        transcription = transcribe_audio(audio_url)
    except requests.RequestException as e:
        # A bad or unreachable URL is a client-side problem: report it as 400
        # rather than letting it bubble up as an opaque 500.
        return jsonify({"error": f"Could not fetch audio: {e}"}), 400

    return jsonify({"transcription": transcription})

if __name__ == '__main__':
    # Bind 0.0.0.0 so the server is reachable from outside the container.
    # NOTE(review): 7860 is the conventional Hugging Face Spaces port —
    # confirm against the actual deployment target.
    app.run(host="0.0.0.0", port=7860)