m1k3wn committed
Commit 16c0a32 · verified · 1 Parent(s): 67011e3

Update app.py

Files changed (1)
  1. app.py +20 -162
app.py CHANGED
@@ -1,3 +1,4 @@
+import torch
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
@@ -6,58 +7,14 @@ import logging
 import os
 import sys
 import traceback
+import psutil
 from functools import lru_cache
 
-# Initialize FastAPI
-app = FastAPI()
-
-# Set up logging with more detailed formatting
-logging.basicConfig(
-    level=logging.DEBUG,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-# Get HF token
-HF_TOKEN = os.environ.get("HF_TOKEN")
-if not HF_TOKEN:
-    logger.warning("No HF_TOKEN found in environment variables")
-
-MODELS = {
-    "nidra-v1": "m1k3wn/nidra-v1",
-    "nidra-v2": "m1k3wn/nidra-v2"
-}
-
-DEFAULT_GENERATION_CONFIGS = {
-    "nidra-v1": {
-        "max_length": 300,
-        "min_length": 150,
-        "num_beams": 8,
-        "temperature": 0.55,
-        "do_sample": True,
-        "top_p": 0.95,
-        "repetition_penalty": 4.5,
-        "no_repeat_ngram_size": 4,
-        "early_stopping": True,
-        "length_penalty": 1.2,
-    },
-    "nidra-v2": {
-        "max_length": 300,
-        "min_length": 150,
-        "num_beams": 8,
-        "temperature": 0.4,
-        "do_sample": True,
-        "top_p": 0.95,
-        "repetition_penalty": 3.5,
-        "no_repeat_ngram_size": 4,
-        "early_stopping": True,
-        "length_penalty": 1.2,
-    }
-}
+[... rest of your existing code until ModelManager class ...]
 
 class ModelManager:
     _instances: ClassVar[Dict[str, tuple]] = {}
-
+
     @classmethod
     def get_model_and_tokenizer(cls, model_name: str):
         if model_name not in cls._instances:
@@ -72,24 +29,13 @@ class ModelManager:
             )
 
             logger.info(f"Loading model {model_name}")
-            # Check if accelerate is available
-            try:
-                import accelerate
-                logger.info("Accelerate package found, using device_map='auto'")
-                model = T5ForConditionalGeneration.from_pretrained(
-                    model_path,
-                    token=HF_TOKEN,
-                    local_files_only=False,
-                    device_map="auto"
-                )
-            except ImportError:
-                logger.warning("Accelerate package not found, falling back to CPU")
-                model = T5ForConditionalGeneration.from_pretrained(
-                    model_path,
-                    token=HF_TOKEN,
-                    local_files_only=False
-                )
-                model = model.cpu()
+            model = T5ForConditionalGeneration.from_pretrained(
+                model_path,
+                token=HF_TOKEN,
+                local_files_only=False,
+                low_cpu_mem_usage=True,
+                torch_dtype=torch.float32
+            ).cpu()
 
             cls._instances[model_name] = (model, tokenizer)
             logger.info(f"Successfully loaded {model_name}")
@@ -99,106 +45,18 @@ class ModelManager:
                 status_code=500,
                 detail=f"Failed to load model {model_name}: {str(e)}"
             )
-
         return cls._instances[model_name]
 
-class PredictionRequest(BaseModel):
-    inputs: str
-    model: str = "nidra-v1"
-    parameters: Optional[Dict[str, Any]] = None
+[... rest of your existing code until before @app.get("/version") ...]
 
-class PredictionResponse(BaseModel):
-    generated_text: str
-    selected_model: str  # Changed from model_used to avoid namespace conflict
-
-@app.get("/version")
-async def version():
+@app.get("/debug/memory")
+async def memory_usage():
+    process = psutil.Process()
+    memory_info = process.memory_info()
     return {
-        "python_version": sys.version,
-        "models_available": list(MODELS.keys())
+        "memory_used_mb": memory_info.rss / 1024 / 1024,
+        "memory_percent": process.memory_percent(),
+        "cpu_percent": process.cpu_percent()
     }
 
-@app.get("/health")
-async def health():
-    # More comprehensive health check
-    try:
-        # Try to load at least one model to verify functionality
-        ModelManager.get_model_and_tokenizer("nidra-v1")
-        return {
-            "status": "healthy",
-            "loaded_models": list(ModelManager._instances.keys())
-        }
-    except Exception as e:
-        logger.error(f"Health check failed: {str(e)}")
-        return {
-            "status": "unhealthy",
-            "error": str(e)
-        }
-
-@app.post("/predict", response_model=PredictionResponse)
-async def predict(request: PredictionRequest):
-    try:
-        # Validate model
-        if request.model not in MODELS:
-            raise HTTPException(
-                status_code=400,
-                detail=f"Invalid model. Available models: {list(MODELS.keys())}"
-            )
-
-        # Get cached model and tokenizer
-        model, tokenizer = ModelManager.get_model_and_tokenizer(request.model)
-
-        # Get generation parameters
-        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
-
-        # Try to load model's saved generation config
-        try:
-            model_generation_config = model.generation_config
-            generation_params.update({
-                k: v for k, v in model_generation_config.to_dict().items()
-                if v is not None
-            })
-        except Exception as config_load_error:
-            logger.warning(f"Using default generation config: {config_load_error}")
-
-        # Override with request-specific parameters
-        if request.parameters:
-            generation_params.update(request.parameters)
-
-        logger.debug(f"Final generation parameters: {generation_params}")
-
-        # Prepare input
-        full_input = "Interpret this dream: " + request.inputs
-        inputs = tokenizer(
-            full_input,
-            return_tensors="pt",
-            truncation=True,
-            max_length=512,
-            padding=True
-        ).to(model.device)  # Ensure inputs are on same device as model
-
-        # Generate
-        outputs = model.generate(
-            **inputs,
-            **{k: v for k, v in generation_params.items() if k in [
-                'max_length', 'min_length', 'do_sample', 'temperature',
-                'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
-                'repetition_penalty', 'early_stopping'
-            ]}
-        )
-
-        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        return PredictionResponse(
-            generated_text=result,
-            selected_model=request.model
-        )
-
-    except Exception as e:
-        error_msg = f"Error during prediction: {str(e)}\n{traceback.format_exc()}"
-        logger.error(error_msg)
-        raise HTTPException(status_code=500, detail=error_msg)
-
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+[... rest of your existing code ...]
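
For reference, the new /debug/memory endpoint reports psutil process stats, which makes it easy to watch the effect of the leaner CPU loading path (low_cpu_mem_usage=True, explicit torch.float32, .cpu()). A minimal client sketch, assuming the app is reachable at a hypothetical localhost URL; port 7860 matches the uvicorn.run() call that previously lived in app.py:

# Minimal sketch: poll the /debug/memory endpoint added in this commit.
# The base URL is an assumption; adjust to wherever the Space is served.
import requests

resp = requests.get("http://localhost:7860/debug/memory", timeout=10)
resp.raise_for_status()
stats = resp.json()

# Keys mirror the endpoint's response dict.
print(f"RSS used:       {stats['memory_used_mb']:.1f} MB")
print(f"Memory percent: {stats['memory_percent']:.1f}%")
print(f"CPU percent:    {stats['cpu_percent']:.1f}%")

Note that psutil's memory_info().rss is the resident set size in bytes, hence the /1024/1024 conversion to megabytes inside the endpoint.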