File size: 2,267 Bytes
d2be6df
886af50
e6cfbd7
 
84b9b91
221d07a
3d55353
aa43ea6
 
84b9b91
886af50
 
aa43ea6
10c5d51
221d07a
 
 
 
 
0348b75
f34f2ca
ba4bba3
aa43ea6
 
 
 
 
 
040da24
aa43ea6
040da24
 
aa43ea6
 
 
040da24
aa43ea6
 
040da24
aa43ea6
 
 
040da24
 
 
 
 
 
aa43ea6
 
040da24
aa43ea6
040da24
 
 
 
 
 
 
 
 
aa43ea6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os

# HF_HOME must be set BEFORE importing transformers: the transformers /
# huggingface_hub stack resolves its cache directories at import time, so
# setting it after the import (as this file originally did) can leave the
# default — possibly read-only — cache location in effect.
os.environ['HF_HOME'] = '/tmp/hf_cache'

from flask import Flask, request, jsonify
import requests
import torch
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration

app = Flask(__name__)

# Temporarily using smaller model for faster testing
model_id = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)

# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Force Hebrew transcription output regardless of the detected language.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")

def transcribe_audio(audio_url):
    """Download the audio at *audio_url* and return its Hebrew transcription.

    The audio is resampled to 16 kHz mono, truncated to a 1-hour safety
    limit, split into ~25 s chunks (under Whisper's 30 s context window),
    and each chunk is transcribed independently; chunk texts are joined
    with newlines.

    Args:
        audio_url: HTTP(S) URL of an audio file readable by librosa.

    Returns:
        The concatenated transcription text, stripped of surrounding
        whitespace.

    Raises:
        requests.RequestException: if the download fails, times out, or
            returns a non-2xx status.
    """
    import tempfile  # local import: only needed here

    # Bounded timeout so a dead host cannot hang the worker forever.
    response = requests.get(audio_url, timeout=60)
    # Fail fast on HTTP errors instead of feeding an error page to librosa.
    response.raise_for_status()

    # Unique temp file so concurrent requests don't clobber each other
    # (a fixed /tmp path is a race between simultaneous callers).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(response.content)
        audio_path = tmp.name

    try:
        # Whisper expects 16 kHz input; librosa resamples on load.
        waveform, sr = librosa.load(audio_path, sr=16000)
    finally:
        os.remove(audio_path)  # don't leak temp files

    # Safety limit (1 hour) to bound memory and compute.
    max_duration_sec = 3600
    waveform = waveform[:sr * max_duration_sec]

    # Split into chunks shorter than Whisper's 30 s window.
    chunk_duration_sec = 25
    chunk_size = sr * chunk_duration_sec
    chunks = [waveform[i : i + chunk_size] for i in range(0, len(waveform), chunk_size)]

    pieces = []
    for chunk in chunks:
        inputs = processor(
            chunk,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        )
        input_features = inputs.input_features.to(device)

        # Inference only — no gradient tracking.
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids
            )

        transcription = processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]
        pieces.append(transcription)

    # join + strip reproduces the original "chunk\n...chunk\n".strip() result
    # without quadratic string concatenation.
    return "\n".join(pieces).strip()

@app.route('/transcribe', methods=['POST'])
def transcribe_endpoint():
    """POST {"audio_url": <url>} -> {"transcription": <text>}.

    Returns 400 for a missing/invalid body or missing 'audio_url',
    502 when the audio download fails, 200 with the transcription otherwise.
    """
    # silent=True yields None for a missing or non-JSON body instead of
    # raising, so malformed requests get a clean 400 rather than a 500.
    data = request.get_json(silent=True)
    audio_url = data.get('audio_url') if data else None
    if not audio_url:
        return jsonify({"error": "Missing 'audio_url' in request"}), 400

    try:
        transcription = transcribe_audio(audio_url)
    except requests.RequestException as e:
        # Download failure (bad URL, timeout, HTTP error) is an upstream
        # problem — report it as a gateway error, not a server crash.
        return jsonify({"error": f"Failed to fetch audio: {e}"}), 502

    return jsonify({"transcription": transcription})

if __name__ == '__main__':
    # Listen on all interfaces. NOTE(review): port 7860 matches the
    # Hugging Face Spaces convention — confirm the deployment target.
    app.run(host="0.0.0.0", port=7860)