tezuesh committed
Commit 18eda2e · verified · 1 Parent(s): 3a00ea9

Update server.py

Files changed (1)
  1. server.py +14 -31
server.py CHANGED
@@ -1,7 +1,6 @@
 from fastapi import FastAPI, HTTPException
 import numpy as np
 import torch
-from pydantic import BaseModel
 import base64
 import io
 import os
@@ -10,6 +9,20 @@ from pathlib import Path
 from inference import InferenceRecipe
 from fastapi.middleware.cors import CORSMiddleware
 
+# Add these imports and configurations at the top
+import torch._inductor
+import torch._dynamo
+
+# Configure Inductor/Triton cache and fallback behavior
+os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
+os.environ["TORCH_INDUCTOR_CACHE_DIR"] = "/tmp/torch_cache"
+torch._inductor.config.suppress_errors = True
+torch._dynamo.config.suppress_errors = True
+
+# Create cache directories with correct permissions
+os.makedirs("/tmp/triton_cache", exist_ok=True)
+os.makedirs("/tmp/torch_cache", exist_ok=True)
+
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
@@ -41,7 +54,6 @@ INITIALIZATION_STATUS = {
 # Global model instance
 model = None
 
-
 def initialize_model():
     """Initialize the model with correct path resolution"""
     global model, INITIALIZATION_STATUS
@@ -93,35 +105,6 @@ def health_check():
 
     return status
 
-# @app.post("/api/v1/inference")
-# async def inference(request: AudioRequest) -> AudioResponse:
-#     """Run inference on audio input"""
-#     if not INITIALIZATION_STATUS["model_loaded"]:
-#         raise HTTPException(
-#             status_code=503,
-#             detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
-#         )
-
-#     try:
-#         # Decode audio from base64
-#         audio_bytes = base64.b64decode(request.audio_data)
-#         audio_array = np.load(io.BytesIO(audio_bytes))
-
-#         # Run inference
-#         result = model.inference(audio_array, request.sample_rate)
-
-#         # Encode output audio
-#         buffer = io.BytesIO()
-#         np.save(buffer, result['audio'])
-#         audio_b64 = base64.b64encode(buffer.getvalue()).decode()
-
-#         return AudioResponse(
-#             audio_data=audio_b64,
-#             text=result.get("text", "")
-#         )
-#     except Exception as e:
-#         logger.error(f"Inference failed: {str(e)}")
-#         raise HTTPException(status_code=500, detail=str(e))
 @app.post("/api/v1/inference")
 async def inference(request: AudioRequest) -> AudioResponse:
     """Run inference with enhanced error handling and logging"""
 