File size: 1,899 Bytes
d2be6df
886af50
e6cfbd7
 
221d07a
3d55353
886af50
 
d2be6df
 
221d07a
 
 
 
 
0348b75
f34f2ca
ba4bba3
886af50
 
 
 
 
 
 
 
4cca673
886af50
 
e2ba5da
ba4bba3
 
e2ba5da
 
ba4bba3
 
bceee03
 
 
886af50
 
8be8710
bceee03
2a0a17e
886af50
 
 
bceee03
 
 
2a0a17e
886af50
2a0a17e
bceee03
4cca673
886af50
bceee03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import tempfile

import librosa
import requests
import torch
from flask import Flask, request, jsonify
from transformers import WhisperProcessor, WhisperForConditionalGeneration

app = Flask(__name__)

# Temporarily using smaller model for faster testing
model_id = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)

# Prefer GPU when one is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Force Hebrew ("he") transcription instead of letting Whisper
# auto-detect the language or translate to English.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="he", task="transcribe")

def transcribe_audio(audio_url):
    """Download the audio file at *audio_url* and return its transcription.

    The audio is resampled to 16 kHz, truncated to at most one hour, split
    into chunks, and each chunk is transcribed with the module-level Whisper
    model (forced to Hebrew via ``forced_decoder_ids``). Chunk transcriptions
    are joined with newlines.

    Raises:
        requests.RequestException: if the download fails or returns a
            non-2xx HTTP status.
    """
    # Timeout so a stalled download cannot hang the worker forever;
    # raise_for_status so an error page is never fed to librosa as audio.
    response = requests.get(audio_url, timeout=60)
    response.raise_for_status()

    # Use a private temporary file instead of a fixed "temp_audio.wav":
    # concurrent requests would otherwise overwrite each other's audio,
    # and the file was previously never cleaned up.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        tmp.write(response.content)
        tmp.close()
        waveform, sr = librosa.load(tmp.name, sr=16000)
    finally:
        tmp.close()
        os.unlink(tmp.name)

    # Cap processing at one hour of audio.
    max_duration_sec = 3600
    waveform = waveform[:sr * max_duration_sec]

    # Whisper processes up to ~30 s per forward pass; 25 s chunks leave margin.
    chunk_duration_sec = 25
    chunk_size = sr * chunk_duration_sec
    chunks = [waveform[i:i + chunk_size] for i in range(0, len(waveform), chunk_size)]

    # Collect per-chunk text and join once (avoids quadratic string +=).
    parts = []
    for chunk in chunks:
        inputs = processor(chunk, sampling_rate=16000, return_tensors="pt", padding=True)
        input_features = inputs.input_features.to(device)

        with torch.no_grad():
            predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)

        parts.append(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])

    return "\n".join(parts).strip()

@app.route('/transcribe', methods=['POST'])
def transcribe_endpoint():
    """POST /transcribe — transcribe the audio at ``audio_url`` in the JSON body.

    Expects ``{"audio_url": "<url>"}``; responds with
    ``{"transcription": "<text>"}`` on success, a 400 for a missing/invalid
    body, or a 502 if the audio could not be downloaded.
    """
    # silent=True makes a non-JSON body yield None instead of raising,
    # so malformed requests get a clean 400 rather than a 500.
    data = request.get_json(silent=True)
    audio_url = data.get('audio_url') if data else None
    if not audio_url:
        return jsonify({"error": "Missing 'audio_url' in request"}), 400

    try:
        transcription = transcribe_audio(audio_url)
    except requests.RequestException as e:
        # Upstream fetch failed — report it as a gateway error, not a crash.
        return jsonify({"error": f"Failed to download audio: {e}"}), 502

    return jsonify({"transcription": transcription})

if __name__ == '__main__':
    # Listen on every interface; 7860 is the conventional HF Spaces port.
    app.run(port=7860, host="0.0.0.0")