rrevo commited on
Commit
a405ea7
·
1 Parent(s): 4843865
Files changed (1) hide show
  1. server/src/main.py +19 -9
server/src/main.py CHANGED
@@ -6,31 +6,41 @@ import torch
6
  from transformers import pipeline
7
 
8
  from fastapi import FastAPI
 
9
 
10
  app = FastAPI()
11
 
12
  DEVICE = os.getenv('DEVICE', 'mps')
13
  ATTN_IMPLEMENTATION = os.getenv('ATTN_IMPLEMENTATION', "sdpa")
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  @app.get("/")
16
  def read_root():
17
  return {"status": "ok"}
18
 
19
 
20
- TRANSCRIBE_PIPELINE = pipeline(
21
- "automatic-speech-recognition",
22
- model="openai/whisper-large-v3",
23
- torch_dtype=torch.float16 if ATTN_IMPLEMENTATION == "sdpa" else torch.bfloat16,
24
- device=DEVICE,
25
- model_kwargs={"attn_implementation": ATTN_IMPLEMENTATION},
26
- )
27
-
28
 
29
  @app.post("/transcribe")
30
  async def transcribe(request: Request):
31
  body = await request.body()
32
  audio_chunk = pickle.loads(body)
33
- outputs = TRANSCRIBE_PIPELINE(
34
  audio_chunk,
35
  chunk_length_s=30,
36
  batch_size=24,
 
6
  from transformers import pipeline
7
 
8
  from fastapi import FastAPI
9
+ from contextlib import asynccontextmanager
10
 
11
  app = FastAPI()
12
 
13
  DEVICE = os.getenv('DEVICE', 'mps')
14
  ATTN_IMPLEMENTATION = os.getenv('ATTN_IMPLEMENTATION', "sdpa")
15
 
16
+ transcribe_pipeline = None
17
+
18
+
19
+ @asynccontextmanager
20
+ async def lifespan(app: FastAPI):
21
+ transcribe_pipeline = pipeline(
22
+ "automatic-speech-recognition",
23
+ model="openai/whisper-large-v3",
24
+ torch_dtype=torch.float16 if ATTN_IMPLEMENTATION == "sdpa" else torch.bfloat16,
25
+ device=DEVICE,
26
+ model_kwargs={"attn_implementation": ATTN_IMPLEMENTATION},
27
+ )
28
+ transcribe_pipeline.model.to('cuda')
29
+ yield
30
+
31
+
32
+
33
  @app.get("/")
34
  def read_root():
35
  return {"status": "ok"}
36
 
37
 
 
 
 
 
 
 
 
 
38
 
39
  @app.post("/transcribe")
40
  async def transcribe(request: Request):
41
  body = await request.body()
42
  audio_chunk = pickle.loads(body)
43
+ outputs = transcribe_pipeline(
44
  audio_chunk,
45
  chunk_length_s=30,
46
  batch_size=24,