m1k3wn committed on
Commit
1ec19be
·
verified ·
1 Parent(s): 925be36

Update app.py

Browse files

major refactor to utilise increased vCPUs. Improves memory cleanup etc

Files changed (1) hide show
  1. app.py +258 -59
app.py CHANGED
@@ -1,3 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
@@ -8,6 +216,10 @@ import os
8
  import sys
9
  import traceback
10
  from functools import lru_cache
 
 
 
 
11
 
12
  # Initialize FastAPI
13
  app = FastAPI()
@@ -58,52 +270,39 @@ DEFAULT_GENERATION_CONFIGS = {
58
 
59
  class ModelManager:
60
  _instances: ClassVar[Dict[str, tuple]] = {}
 
61
 
62
  @classmethod
63
- def get_model_and_tokenizer(cls, model_name: str):
64
- if model_name not in cls._instances:
65
- try:
66
- model_path = MODELS[model_name]
67
- logger.info(f"Loading tokenizer for {model_name}")
68
- tokenizer = T5Tokenizer.from_pretrained(
69
- model_path,
70
- token=HF_TOKEN,
71
- local_files_only=False,
72
- return_special_tokens_mask=True
73
- )
74
-
75
- logger.info(f"Loading model {model_name}")
76
- # Check if accelerate is available
77
  try:
78
- import accelerate
79
- logger.info("Accelerate package found, using device_map='auto'")
80
- model = T5ForConditionalGeneration.from_pretrained(
81
  model_path,
82
  token=HF_TOKEN,
83
- local_files_only=False,
84
- device_map="auto",
85
- low_cpu_mem_usage=True,
86
- torch_dtype=torch.float32
87
  )
88
- except ImportError:
89
- logger.warning("Accelerate package not found, falling back to CPU")
90
  model = T5ForConditionalGeneration.from_pretrained(
91
  model_path,
92
  token=HF_TOKEN,
93
- local_files_only=False
 
 
94
  )
95
- model = model.cpu()
96
-
97
- cls._instances[model_name] = (model, tokenizer)
98
- logger.info(f"Successfully loaded {model_name}")
99
- except Exception as e:
100
- logger.error(f"Error loading {model_name}: {str(e)}")
101
- raise HTTPException(
102
- status_code=500,
103
- detail=f"Failed to load model {model_name}: {str(e)}"
104
- )
105
-
106
- return cls._instances[model_name]
107
 
108
  class PredictionRequest(BaseModel):
109
  inputs: str
@@ -133,10 +332,8 @@ async def version():
133
 
134
  @app.get("/health")
135
  async def health():
136
- # More comprehensive health check
137
  try:
138
- # Try to load at least one model to verify functionality
139
- ModelManager.get_model_and_tokenizer("nidra-v1")
140
  return {
141
  "status": "healthy",
142
  "loaded_models": list(ModelManager._instances.keys())
@@ -149,22 +346,17 @@ async def health():
149
  }
150
 
151
  @app.post("/predict", response_model=PredictionResponse)
152
- async def predict(request: PredictionRequest):
153
  try:
154
- # Validate model
155
  if request.model not in MODELS:
156
  raise HTTPException(
157
  status_code=400,
158
  detail=f"Invalid model. Available models: {list(MODELS.keys())}"
159
  )
160
 
161
- # Get cached model and tokenizer
162
- model, tokenizer = ModelManager.get_model_and_tokenizer(request.model)
163
-
164
- # Get generation parameters
165
  generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
166
 
167
- # Try to load model's saved generation config
168
  try:
169
  model_generation_config = model.generation_config
170
  generation_params.update({
@@ -174,33 +366,36 @@ async def predict(request: PredictionRequest):
174
  except Exception as config_load_error:
175
  logger.warning(f"Using default generation config: {config_load_error}")
176
 
177
- # Override with request-specific parameters
178
  if request.parameters:
179
  generation_params.update(request.parameters)
180
 
181
  logger.debug(f"Final generation parameters: {generation_params}")
182
 
183
- # Prepare input
184
  full_input = "Interpret this dream: " + request.inputs
185
  inputs = tokenizer(
186
  full_input,
187
  return_tensors="pt",
188
  truncation=True,
189
  max_length=512,
190
- padding=True
191
- ).to(model.device) # Ensure inputs are on same device as model
192
-
193
- outputs = model.generate(
194
- **inputs,
195
- max_time=90.0, # 90 second timeout
196
- **{k: v for k, v in generation_params.items() if k in [
197
- 'max_length', 'min_length', 'do_sample', 'temperature',
198
- 'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
199
- 'repetition_penalty', 'early_stopping'
200
- ]}
201
  )
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
204
 
205
  return PredictionResponse(
206
  generated_text=result,
@@ -212,6 +407,10 @@ async def predict(request: PredictionRequest):
212
  logger.error(error_msg)
213
  raise HTTPException(status_code=500, detail=error_msg)
214
 
 
 
 
 
215
  if __name__ == "__main__":
216
  import uvicorn
217
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ ```pythonimport torch
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+ from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
5
+ from typing import Optional, Dict, Any, ClassVar
6
+ import logging
7
+ import os
8
+ import sys
9
+ import traceback
10
+ from functools import lru_cache
11
+ import gc
12
+ import asyncio
13
+ from fastapi import BackgroundTasks
14
+ import psutil
15
+
16
+ # Initialize FastAPI
17
+ app = FastAPI()
18
+
19
+ # Debugging logging with detailed formatting
20
+ # logging.basicConfig(
21
+ # level=logging.DEBUG,
22
+ # format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
23
+ # )
24
+ # logger = logging.getLogger(__name__)
25
+
26
+ # Get HF token
27
+ HF_TOKEN = os.environ.get("HF_TOKEN")
28
+ if not HF_TOKEN:
29
+ logger.warning("No HF_TOKEN found in environment variables")
30
+
31
+ MODELS = {
32
+ "nidra-v1": "m1k3wn/nidra-v1",
33
+ "nidra-v2": "m1k3wn/nidra-v2"
34
+ }
35
+
36
+ DEFAULT_GENERATION_CONFIGS = {
37
+ "nidra-v1": {
38
+ "max_length": 300,
39
+ "min_length": 150,
40
+ "num_beams": 8,
41
+ "temperature": 0.55,
42
+ "do_sample": True,
43
+ "top_p": 0.95,
44
+ "repetition_penalty": 4.5,
45
+ "no_repeat_ngram_size": 4,
46
+ "early_stopping": True,
47
+ "length_penalty": 1.2,
48
+ },
49
+ "nidra-v2": {
50
+ "max_length": 300,
51
+ "min_length": 150,
52
+ "num_beams": 8,
53
+ "temperature": 0.4,
54
+ "do_sample": True,
55
+ "top_p": 0.95,
56
+ "repetition_penalty": 3.5,
57
+ "no_repeat_ngram_size": 4,
58
+ "early_stopping": True,
59
+ "length_penalty": 1.2,
60
+ }
61
+ }
62
+
63
+ class ModelManager:
64
+ _instances: ClassVar[Dict[str, tuple]] = {}
65
+ _lock = asyncio.Lock() # Add lock for thread safety
66
+
67
+ @classmethod
68
+ async def get_model_and_tokenizer(cls, model_name: str):
69
+ async with cls._lock:
70
+ if model_name not in cls._instances:
71
+ try:
72
+ model_path = MODELS[model_name]
73
+ tokenizer = T5Tokenizer.from_pretrained(
74
+ model_path,
75
+ token=HF_TOKEN,
76
+ local_files_only=True # Cache after first load
77
+ )
78
+
79
+ model = T5ForConditionalGeneration.from_pretrained(
80
+ model_path,
81
+ token=HF_TOKEN,
82
+ local_files_only=True,
83
+ low_cpu_mem_usage=True,
84
+ torch_dtype=torch.float32
85
+ )
86
+
87
+ # Enable parallel processing
88
+ model.eval()
89
+ torch.set_num_threads(8) # Use all CPU cores
90
+
91
+ cls._instances[model_name] = (model, tokenizer)
92
+
93
+ except Exception as e:
94
+ logger.error(f"Error loading {model_name}: {str(e)}")
95
+ raise
96
+
97
+ return cls._instances[model_name]
98
+
99
+ class PredictionRequest(BaseModel):
100
+ inputs: str
101
+ model: str = "nidra-v1"
102
+ parameters: Optional[Dict[str, Any]] = None
103
+
104
+ class PredictionResponse(BaseModel):
105
+ generated_text: str
106
+ selected_model: str # Changed from model_used to avoid namespace conflict
107
+
108
+ @app.get("/debug/memory")
109
+ async def memory_usage():
110
+ process = psutil.Process()
111
+ memory_info = process.memory_info()
112
+ return {
113
+ "memory_used_mb": memory_info.rss / 1024 / 1024,
114
+ "memory_percent": process.memory_percent(),
115
+ "cpu_percent": process.cpu_percent()
116
+ }
117
+
118
+ @app.get("/version")
119
+ async def version():
120
+ return {
121
+ "python_version": sys.version,
122
+ "models_available": list(MODELS.keys())
123
+ }
124
+
125
+ @app.get("/health")
126
+ async def health():
127
+ try:
128
+ await ModelManager.get_model_and_tokenizer("nidra-v1")
129
+ return {
130
+ "status": "healthy",
131
+ "loaded_models": list(ModelManager._instances.keys())
132
+ }
133
+ except Exception as e:
134
+ logger.error(f"Health check failed: {str(e)}")
135
+ return {
136
+ "status": "unhealthy",
137
+ "error": str(e)
138
+ }
139
+
140
+ @app.post("/predict", response_model=PredictionResponse)
141
+ async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
142
+ try:
143
+ if request.model not in MODELS:
144
+ raise HTTPException(
145
+ status_code=400,
146
+ detail=f"Invalid model. Available models: {list(MODELS.keys())}"
147
+ )
148
+
149
+ model, tokenizer = await ModelManager.get_model_and_tokenizer(request.model)
150
+ generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
151
+
152
+ try:
153
+ model_generation_config = model.generation_config
154
+ generation_params.update({
155
+ k: v for k, v in model_generation_config.to_dict().items()
156
+ if v is not None
157
+ })
158
+ except Exception as config_load_error:
159
+ logger.warning(f"Using default generation config: {config_load_error}")
160
+
161
+ if request.parameters:
162
+ generation_params.update(request.parameters)
163
+
164
+ logger.debug(f"Final generation parameters: {generation_params}")
165
+
166
+ full_input = "Interpret this dream: " + request.inputs
167
+ inputs = tokenizer(
168
+ full_input,
169
+ return_tensors="pt",
170
+ truncation=True,
171
+ max_length=512,
172
+ padding=True,
173
+ return_attention_mask=True
174
+ )
175
+
176
+ async def generate():
177
+ return model.generate(
178
+ **inputs,
179
+ **{k: v for k, v in generation_params.items() if k in [
180
+ 'max_length', 'min_length', 'do_sample', 'temperature',
181
+ 'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
182
+ 'repetition_penalty', 'early_stopping'
183
+ ]}
184
+ )
185
+
186
+ with torch.inference_mode():
187
+ outputs = await asyncio.wait_for(generate(), timeout=70.0)
188
+
189
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
190
+ background_tasks.add_task(cleanup_memory)
191
+
192
+ return PredictionResponse(
193
+ generated_text=result,
194
+ selected_model=request.model
195
+ )
196
+
197
+ except Exception as e:
198
+ error_msg = f"Error during prediction: {str(e)}\n{traceback.format_exc()}"
199
+ logger.error(error_msg)
200
+ raise HTTPException(status_code=500, detail=error_msg)
201
+
202
+ def cleanup_memory():
203
+ gc.collect()
204
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
205
+
206
+ if __name__ == "__main__":
207
+ import uvicorn
208
+ uvicorn.run(app, host="0.0.0.0", port=7860)
209
  import torch
210
  from fastapi import FastAPI, HTTPException
211
  from pydantic import BaseModel
 
216
  import sys
217
  import traceback
218
  from functools import lru_cache
219
+ import gc
220
+ import asyncio
221
+ from fastapi import BackgroundTasks
222
+ import psutil
223
 
224
  # Initialize FastAPI
225
  app = FastAPI()
 
270
 
271
  class ModelManager:
272
  _instances: ClassVar[Dict[str, tuple]] = {}
273
+ _lock = asyncio.Lock() # Add lock for thread safety
274
 
275
  @classmethod
276
+ async def get_model_and_tokenizer(cls, model_name: str):
277
+ async with cls._lock:
278
+ if model_name not in cls._instances:
 
 
 
 
 
 
 
 
 
 
 
279
  try:
280
+ model_path = MODELS[model_name]
281
+ tokenizer = T5Tokenizer.from_pretrained(
 
282
  model_path,
283
  token=HF_TOKEN,
284
+ local_files_only=True # Cache after first load
 
 
 
285
  )
286
+
 
287
  model = T5ForConditionalGeneration.from_pretrained(
288
  model_path,
289
  token=HF_TOKEN,
290
+ local_files_only=True,
291
+ low_cpu_mem_usage=True,
292
+ torch_dtype=torch.float32
293
  )
294
+
295
+ # Enable parallel processing
296
+ model.eval()
297
+ torch.set_num_threads(8) # Use all CPU cores
298
+
299
+ cls._instances[model_name] = (model, tokenizer)
300
+
301
+ except Exception as e:
302
+ logger.error(f"Error loading {model_name}: {str(e)}")
303
+ raise
304
+
305
+ return cls._instances[model_name]
306
 
307
  class PredictionRequest(BaseModel):
308
  inputs: str
 
332
 
333
  @app.get("/health")
334
  async def health():
 
335
  try:
336
+ await ModelManager.get_model_and_tokenizer("nidra-v1")
 
337
  return {
338
  "status": "healthy",
339
  "loaded_models": list(ModelManager._instances.keys())
 
346
  }
347
 
348
  @app.post("/predict", response_model=PredictionResponse)
349
+ async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
350
  try:
 
351
  if request.model not in MODELS:
352
  raise HTTPException(
353
  status_code=400,
354
  detail=f"Invalid model. Available models: {list(MODELS.keys())}"
355
  )
356
 
357
+ model, tokenizer = await ModelManager.get_model_and_tokenizer(request.model)
 
 
 
358
  generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
359
 
 
360
  try:
361
  model_generation_config = model.generation_config
362
  generation_params.update({
 
366
  except Exception as config_load_error:
367
  logger.warning(f"Using default generation config: {config_load_error}")
368
 
 
369
  if request.parameters:
370
  generation_params.update(request.parameters)
371
 
372
  logger.debug(f"Final generation parameters: {generation_params}")
373
 
 
374
  full_input = "Interpret this dream: " + request.inputs
375
  inputs = tokenizer(
376
  full_input,
377
  return_tensors="pt",
378
  truncation=True,
379
  max_length=512,
380
+ padding=True,
381
+ return_attention_mask=True
 
 
 
 
 
 
 
 
 
382
  )
383
 
384
+ async def generate():
385
+ return model.generate(
386
+ **inputs,
387
+ **{k: v for k, v in generation_params.items() if k in [
388
+ 'max_length', 'min_length', 'do_sample', 'temperature',
389
+ 'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
390
+ 'repetition_penalty', 'early_stopping'
391
+ ]}
392
+ )
393
+
394
+ with torch.inference_mode():
395
+ outputs = await asyncio.wait_for(generate(), timeout=70.0)
396
+
397
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
398
+ background_tasks.add_task(cleanup_memory)
399
 
400
  return PredictionResponse(
401
  generated_text=result,
 
407
  logger.error(error_msg)
408
  raise HTTPException(status_code=500, detail=error_msg)
409
 
410
+ def cleanup_memory():
411
+ gc.collect()
412
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
413
+
414
  if __name__ == "__main__":
415
  import uvicorn
416
  uvicorn.run(app, host="0.0.0.0", port=7860)