rrevo committed on
Commit
2a8b0bb
·
1 Parent(s): 190e978
Files changed (1) hide show
  1. server/src/main.py +3 -5
server/src/main.py CHANGED
@@ -11,19 +11,17 @@ from contextlib import asynccontextmanager
11
  DEVICE = os.getenv('DEVICE', 'mps')
12
  ATTN_IMPLEMENTATION = os.getenv('ATTN_IMPLEMENTATION', "sdpa")
13
 
14
- transcribe_pipeline = None
15
-
16
 
17
  @asynccontextmanager
18
  async def lifespan(app: FastAPI):
19
- transcribe_pipeline = pipeline(
20
  "automatic-speech-recognition",
21
  model="openai/whisper-large-v3",
22
  torch_dtype=torch.float16 if ATTN_IMPLEMENTATION == "sdpa" else torch.bfloat16,
23
  device=DEVICE,
24
  model_kwargs={"attn_implementation": ATTN_IMPLEMENTATION},
25
  )
26
- transcribe_pipeline.model.to('cuda')
27
  yield
28
 
29
  app = FastAPI(lifespan=lifespan)
@@ -40,7 +38,7 @@ def read_root():
40
  async def transcribe(request: Request):
41
  body = await request.body()
42
  audio_chunk = pickle.loads(body)
43
- outputs = transcribe_pipeline(
44
  audio_chunk,
45
  chunk_length_s=30,
46
  batch_size=24,
 
11
  DEVICE = os.getenv('DEVICE', 'mps')
12
  ATTN_IMPLEMENTATION = os.getenv('ATTN_IMPLEMENTATION', "sdpa")
13
 
 
 
14
 
15
  @asynccontextmanager
16
  async def lifespan(app: FastAPI):
17
+ app.state.transcribe_pipeline = pipeline(
18
  "automatic-speech-recognition",
19
  model="openai/whisper-large-v3",
20
  torch_dtype=torch.float16 if ATTN_IMPLEMENTATION == "sdpa" else torch.bfloat16,
21
  device=DEVICE,
22
  model_kwargs={"attn_implementation": ATTN_IMPLEMENTATION},
23
  )
24
+ app.state.transcribe_pipeline.model.to('cuda')
25
  yield
26
 
27
  app = FastAPI(lifespan=lifespan)
 
38
  async def transcribe(request: Request):
39
  body = await request.body()
40
  audio_chunk = pickle.loads(body)
41
+ outputs = app.state.transcribe_pipeline(
42
  audio_chunk,
43
  chunk_length_s=30,
44
  batch_size=24,