File size: 1,696 Bytes
62b3774
71f6464
db56cf5
71f6464
 
 
209a86b
db56cf5
 
a405ea7
db56cf5
62b3774
 
db56cf5
a405ea7
 
 
209a86b
 
 
 
 
 
 
 
 
 
 
2a8b0bb
a405ea7
209a86b
 
 
 
 
a405ea7
 
 
2b62b02
 
a405ea7
 
db56cf5
 
 
 
 
71f6464
1ba4a0c
71f6464
 
 
2a8b0bb
71f6464
 
 
 
 
 
 
4aedc6e
71f6464
 
b932faf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import pickle
from typing import Union
from fastapi import Request
import torch
from transformers import pipeline
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

from fastapi import FastAPI
from contextlib import asynccontextmanager

# Runtime settings, read once from the process environment at import time.
# NOTE(review): neither value is referenced by the other definitions visible
# in this file -- confirm whether model loading is meant to honor them.
DEVICE = os.environ.get("DEVICE", "mps")
ATTN_IMPLEMENTATION = os.environ.get("ATTN_IMPLEMENTATION", "sdpa")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the Whisper ASR pipeline at startup and release it at shutdown.

    The code before ``yield`` runs once before the application starts
    serving: it loads ``openai/whisper-large-v3`` and stores the resulting
    pipeline on ``app.state.transcribe_pipeline`` so request handlers can
    reuse it. The ``finally`` clause runs when the application shuts down.

    Args:
        app: The FastAPI application whose ``state`` receives the pipeline.

    Yields:
        None. Control returns to the framework while the app is serving.
    """
    # CUDA gets half precision; everything else falls back to CPU / float32.
    # NOTE(review): the module-level DEVICE / ATTN_IMPLEMENTATION settings are
    # not consulted here -- confirm whether that is intentional.
    use_cuda = torch.cuda.is_available()
    torch_dtype = torch.float16 if use_cuda else torch.float32
    device = "cuda:0" if use_cuda else "cpu"
    model_id = "openai/whisper-large-v3"

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    app.state.transcribe_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    try:
        yield
    finally:
        # Drop every reference held here so the weights can be reclaimed
        # when the app is torn down (e.g. repeated start/stop in one process).
        app.state.transcribe_pipeline = None
        del model, processor
        if use_cuda:
            # Hand cached GPU blocks back to the driver.
            torch.cuda.empty_cache()

# Application instance; `lifespan` is registered as its lifespan context
# manager, so the pipeline it builds is available on `app.state`.
app = FastAPI(lifespan=lifespan)



@app.get("/")
def read_root():
    """Return a fixed status payload for ``GET /``."""
    payload = {"status": "ok"}
    return payload



@app.post("/transcribe")
async def transcribe(request: Request):
    """Transcribe the audio carried in the request body.

    The raw body is unpickled and handed to the pipeline stored on
    ``app.state.transcribe_pipeline``, asking for an English transcription.

    Args:
        request: Incoming request whose body is a pickled pipeline input.

    Returns:
        A dict with ``"transcribe"`` (the whitespace-stripped text) and
        ``"outputs"`` (the pipeline's full result).

    Raises:
        HTTPException: 400 when the body is empty or cannot be unpickled.
    """
    # Function-scope import so this handler does not depend on edits to the
    # file's import block.
    from fastapi import HTTPException

    body = await request.body()
    if not body:
        raise HTTPException(status_code=400, detail="Request body is empty")

    # SECURITY: pickle.loads can execute arbitrary code embedded in the
    # payload. Only expose this endpoint to trusted callers, or move to a
    # data-only wire format (e.g. raw audio bytes).
    try:
        audio_chunk = pickle.loads(body)
    except Exception as exc:
        # Unpickling may raise UnpicklingError, EOFError, AttributeError,
        # ImportError, IndexError and others; all of them mean a bad payload,
        # so report a client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Request body is not a valid pickle payload"
        ) from exc

    # NOTE(review): this synchronous call is made directly inside an async
    # handler and is not awaited -- consider offloading it to a worker thread
    # if other requests must be served while a transcription is in progress.
    outputs = app.state.transcribe_pipeline(
        audio_chunk,
        chunk_length_s=30,
        batch_size=24,
        generate_kwargs={
            'task': 'transcribe',
            'language': 'english'
        },
    )
    text = outputs["text"].strip()
    return {"transcribe": text, "outputs": outputs}