import os
import pickle
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, Request
from transformers import pipeline

# Device and attention implementation are configurable via environment variables.
DEVICE = os.getenv('DEVICE', 'mps')
ATTN_IMPLEMENTATION = os.getenv('ATTN_IMPLEMENTATION', "sdpa")

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the Whisper ASR pipeline once at startup and keep it on app.state
    # so request handlers can reuse it.
    app.state.transcribe_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3",
        torch_dtype=torch.float16 if ATTN_IMPLEMENTATION == "sdpa" else torch.bfloat16,
        device=DEVICE,
        model_kwargs={"attn_implementation": ATTN_IMPLEMENTATION},
    )
    # Ensure the model sits on the configured device.
    app.state.transcribe_pipeline.model.to(DEVICE)
    yield

app = FastAPI(lifespan=lifespan)


@app.get("/")
def read_root():
    return {"status": "ok"}
@app.post("/transcribe")
async def transcribe(request: Request):
body = await request.body()
audio_chunk = pickle.loads(body)
outputs = app.state.transcribe_pipeline(
audio_chunk,
chunk_length_s=30,
batch_size=24,
generate_kwargs={
'task': 'transcribe',
'language': 'english'
},
return_timestamps='word'
)
text = outputs["text"].strip()
return {"transcribe": text, "outputs": outputs}