m1k3wn committed
Commit abb0e4d · verified · 1 Parent(s): 4eee935

Update app.py

Files changed (1)
  1. app.py +167 -12
app.py CHANGED
@@ -7,14 +7,58 @@ import logging
 import os
 import sys
 import traceback
-import psutil
 from functools import lru_cache
 
-[... rest of your existing code until ModelManager class ...]
+# Initialize FastAPI
+app = FastAPI()
+
+# Set up logging with more detailed formatting
+logging.basicConfig(
+    level=logging.DEBUG,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# Get HF token
+HF_TOKEN = os.environ.get("HF_TOKEN")
+if not HF_TOKEN:
+    logger.warning("No HF_TOKEN found in environment variables")
+
+MODELS = {
+    "nidra-v1": "m1k3wn/nidra-v1",
+    "nidra-v2": "m1k3wn/nidra-v2"
+}
+
+DEFAULT_GENERATION_CONFIGS = {
+    "nidra-v1": {
+        "max_length": 300,
+        "min_length": 150,
+        "num_beams": 8,
+        "temperature": 0.55,
+        "do_sample": True,
+        "top_p": 0.95,
+        "repetition_penalty": 4.5,
+        "no_repeat_ngram_size": 4,
+        "early_stopping": True,
+        "length_penalty": 1.2,
+    },
+    "nidra-v2": {
+        "max_length": 300,
+        "min_length": 150,
+        "num_beams": 8,
+        "temperature": 0.4,
+        "do_sample": True,
+        "top_p": 0.95,
+        "repetition_penalty": 3.5,
+        "no_repeat_ngram_size": 4,
+        "early_stopping": True,
+        "length_penalty": 1.2,
+    }
+}
 
 class ModelManager:
     _instances: ClassVar[Dict[str, tuple]] = {}
-
+
     @classmethod
     def get_model_and_tokenizer(cls, model_name: str):
         if model_name not in cls._instances:
@@ -29,13 +73,26 @@ class ModelManager:
                 )
 
                 logger.info(f"Loading model {model_name}")
-                model = T5ForConditionalGeneration.from_pretrained(
-                    model_path,
-                    token=HF_TOKEN,
-                    local_files_only=False,
-                    low_cpu_mem_usage=True,
-                    torch_dtype=torch.float32
-                ).cpu()
+                # Check if accelerate is available
+                try:
+                    import accelerate
+                    logger.info("Accelerate package found, using device_map='auto'")
+                    model = T5ForConditionalGeneration.from_pretrained(
+                        model_path,
+                        token=HF_TOKEN,
+                        local_files_only=False,
+                        device_map="auto",
+                        low_cpu_mem_usage=True,
+                        torch_dtype=torch.float32
+                    )
+                except ImportError:
+                    logger.warning("Accelerate package not found, falling back to CPU")
+                    model = T5ForConditionalGeneration.from_pretrained(
+                        model_path,
+                        token=HF_TOKEN,
+                        local_files_only=False
+                    )
+                    model = model.cpu()
 
                 cls._instances[model_name] = (model, tokenizer)
                 logger.info(f"Successfully loaded {model_name}")
@@ -45,9 +102,17 @@ class ModelManager:
                     status_code=500,
                     detail=f"Failed to load model {model_name}: {str(e)}"
                 )
+
         return cls._instances[model_name]
 
-[... rest of your existing code until before @app.get("/version") ...]
+class PredictionRequest(BaseModel):
+    inputs: str
+    model: str = "nidra-v1"
+    parameters: Optional[Dict[str, Any]] = None
+
+class PredictionResponse(BaseModel):
+    generated_text: str
+    selected_model: str  # Changed from model_used to avoid namespace conflict
 
 @app.get("/debug/memory")
 async def memory_usage():
@@ -59,4 +124,94 @@ async def memory_usage():
         "cpu_percent": process.cpu_percent()
     }
 
-[... rest of your existing code ...]
+@app.get("/version")
+async def version():
+    return {
+        "python_version": sys.version,
+        "models_available": list(MODELS.keys())
+    }
+
+@app.get("/health")
+async def health():
+    # More comprehensive health check
+    try:
+        # Try to load at least one model to verify functionality
+        ModelManager.get_model_and_tokenizer("nidra-v1")
+        return {
+            "status": "healthy",
+            "loaded_models": list(ModelManager._instances.keys())
+        }
+    except Exception as e:
+        logger.error(f"Health check failed: {str(e)}")
+        return {
+            "status": "unhealthy",
+            "error": str(e)
+        }
+
+@app.post("/predict", response_model=PredictionResponse)
+async def predict(request: PredictionRequest):
+    try:
+        # Validate model
+        if request.model not in MODELS:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Invalid model. Available models: {list(MODELS.keys())}"
+            )
+
+        # Get cached model and tokenizer
+        model, tokenizer = ModelManager.get_model_and_tokenizer(request.model)
+
+        # Get generation parameters
+        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
+
+        # Try to load model's saved generation config
+        try:
+            model_generation_config = model.generation_config
+            generation_params.update({
+                k: v for k, v in model_generation_config.to_dict().items()
+                if v is not None
+            })
+        except Exception as config_load_error:
+            logger.warning(f"Using default generation config: {config_load_error}")
+
+        # Override with request-specific parameters
+        if request.parameters:
+            generation_params.update(request.parameters)
+
+        logger.debug(f"Final generation parameters: {generation_params}")
+
+        # Prepare input
+        full_input = "Interpret this dream: " + request.inputs
+        inputs = tokenizer(
+            full_input,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512,
+            padding=True
+        ).to(model.device)  # Ensure inputs are on same device as model
+
+        # Generate
+        outputs = model.generate(
+            **inputs,
+            **{k: v for k, v in generation_params.items() if k in [
+                'max_length', 'min_length', 'do_sample', 'temperature',
+                'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
+                'repetition_penalty', 'early_stopping'
+            ]}
+        )
+
+        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        return PredictionResponse(
+            generated_text=result,
+            selected_model=request.model
+        )
+
+    except Exception as e:
+        error_msg = f"Error during prediction: {str(e)}\n{traceback.format_exc()}"
+        logger.error(error_msg)
+        raise HTTPException(status_code=500, detail=error_msg)
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
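
For anyone trying this revision out, a minimal client-side sketch of the endpoints added above follows. It assumes the app is running locally via the __main__ block in this commit (port 7860; adjust the URL for a deployed Space). The request fields (inputs, model, parameters) and response fields (generated_text, selected_model) come from the PredictionRequest/PredictionResponse models above; the sample dream text and parameter overrides are purely illustrative.

import requests

BASE_URL = "http://localhost:7860"  # assumes the local uvicorn.run() above; adjust for a deployed Space

# Illustrative payload: "inputs" is the dream text, "model" selects a key from MODELS,
# and the optional "parameters" dict overrides DEFAULT_GENERATION_CONFIGS entries.
payload = {
    "inputs": "I was flying over a city made of glass.",
    "model": "nidra-v1",
    "parameters": {"max_length": 200, "temperature": 0.7},
}

# Smoke-test the new health endpoint; expect {"status": "healthy", ...}
print(requests.get(f"{BASE_URL}/health").json())

resp = requests.post(f"{BASE_URL}/predict", json=payload)
resp.raise_for_status()
body = resp.json()
print(body["selected_model"], "->", body["generated_text"])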