m1k3wn committed · verified
Commit f7ed1d0 · 1 parent: fab4412

Update app.py


Adds debugging and error handling.

Files changed (1): app.py (+100, -61)
app.py CHANGED
@@ -1,28 +1,33 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, ClassVar
 import logging
 import os
 import sys
 import traceback
+from functools import lru_cache
 
-# Initialize FastAPI first
+# Initialize FastAPI
 app = FastAPI()
 
-# Set up logging
-logging.basicConfig(level=logging.DEBUG)
+# Set up logging with more detailed formatting
+logging.basicConfig(
+    level=logging.DEBUG,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
 logger = logging.getLogger(__name__)
 
 # Get HF token
 HF_TOKEN = os.environ.get("HF_TOKEN")
+if not HF_TOKEN:
+    logger.warning("No HF_TOKEN found in environment variables")
 
 MODELS = {
     "nidra-v1": "m1k3wn/nidra-v1",
     "nidra-v2": "m1k3wn/nidra-v2"
 }
 
-# Define default generation configurations for each model
 DEFAULT_GENERATION_CONFIGS = {
     "nidra-v1": {
         "max_length": 300,
@@ -49,105 +54,139 @@ DEFAULT_GENERATION_CONFIGS = {
         "length_penalty": 1.2,
     }
 }
+
+class ModelManager:
+    _instances: ClassVar[Dict[str, tuple]] = {}
+
+    @classmethod
+    def get_model_and_tokenizer(cls, model_name: str):
+        if model_name not in cls._instances:
+            try:
+                model_path = MODELS[model_name]
+                logger.info(f"Loading tokenizer for {model_name}")
+                tokenizer = T5Tokenizer.from_pretrained(
+                    model_path,
+                    token=HF_TOKEN,
+                    local_files_only=False,
+                    return_special_tokens_mask=True
+                )
+
+                logger.info(f"Loading model {model_name}")
+                model = T5ForConditionalGeneration.from_pretrained(
+                    model_path,
+                    token=HF_TOKEN,
+                    local_files_only=False,
+                    device_map="auto"  # This will handle GPU if available
+                )
+
+                cls._instances[model_name] = (model, tokenizer)
+                logger.info(f"Successfully loaded {model_name}")
+            except Exception as e:
+                logger.error(f"Error loading {model_name}: {str(e)}")
+                raise HTTPException(
+                    status_code=500,
+                    detail=f"Failed to load model {model_name}: {str(e)}"
+                )
+
+        return cls._instances[model_name]
+
 class PredictionRequest(BaseModel):
     inputs: str
     model: str = "nidra-v1"
-    parameters: Optional[Dict[str, Any]] = None  # Allow custom parameters
+    parameters: Optional[Dict[str, Any]] = None
 
 class PredictionResponse(BaseModel):
     generated_text: str
+    model_used: str
 
 @app.get("/version")
 async def version():
-    return {"python_version": sys.version}
+    return {
+        "python_version": sys.version,
+        "models_available": list(MODELS.keys())
+    }
 
 @app.get("/health")
 async def health():
-    return {"status": "healthy"}
+    # More comprehensive health check
+    try:
+        # Try to load at least one model to verify functionality
+        ModelManager.get_model_and_tokenizer("nidra-v1")
+        return {
+            "status": "healthy",
+            "loaded_models": list(ModelManager._instances.keys())
+        }
+    except Exception as e:
+        logger.error(f"Health check failed: {str(e)}")
+        return {
+            "status": "unhealthy",
+            "error": str(e)
+        }
 
 @app.post("/predict", response_model=PredictionResponse)
 async def predict(request: PredictionRequest):
     try:
-        # Validate model
+        # Validate model
         if request.model not in MODELS:
-            raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")
-
-        logger.info(f"Loading model: {request.model}")
-        model_path = MODELS[request.model]
-
-        # Add debug logging
-        logger.info("Attempting to load tokenizer...")
-        tokenizer = T5Tokenizer.from_pretrained(
-            model_path,
-            token=HF_TOKEN,
-            local_files_only=False,
-            return_special_tokens_mask=True
-        )
-        logger.info("Tokenizer loaded successfully")
-
-        logger.info("Attempting to load model...")
-        model = T5ForConditionalGeneration.from_pretrained(
-            model_path,
-            token=HF_TOKEN,
-            local_files_only=False
-        )
-        logger.info("Model loaded successfully")
+            raise HTTPException(
+                status_code=400,
+                detail=f"Invalid model. Available models: {list(MODELS.keys())}"
+            )
+
+        # Get cached model and tokenizer
+        model, tokenizer = ModelManager.get_model_and_tokenizer(request.model)
 
-        # Priority: 1. Request parameters, 2. Model's saved generation_config, 3. Default configs
+        # Get generation parameters
         generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
 
         # Try to load model's saved generation config
        try:
-            model_generation_config = GenerationConfig.from_pretrained(model_path)
-            # Convert to dict to merge with default configs
+            model_generation_config = model.generation_config
             generation_params.update({
-                k: v for k, v in model_generation_config.to_dict().items()
+                k: v for k, v in model_generation_config.to_dict().items()
                 if v is not None
             })
         except Exception as config_load_error:
-            logger.warning(f"Could not load model's generation config: {config_load_error}")
+            logger.warning(f"Using default generation config: {config_load_error}")
 
-        # Override with request-specific parameters if provided
+        # Override with request-specific parameters
         if request.parameters:
             generation_params.update(request.parameters)
 
-        logger.info(f"Final Generation Parameters: {generation_params}")
+        logger.debug(f"Final generation parameters: {generation_params}")
 
-
+        # Prepare input
         full_input = "Interpret this dream: " + request.inputs
-        logger.info(f"Processing input: {full_input}")
-
-        logger.info("Tokenizing input...")
         inputs = tokenizer(
             full_input,
             return_tensors="pt",
             truncation=True,
             max_length=512,
             padding=True
-        )
-        logger.info("Input tokenized successfully")
-
-        logger.info("Generating output...")
-
-        # Generate with final parameters
+        ).to(model.device)  # Ensure inputs are on same device as model
+
+        # Generate
         outputs = model.generate(
-            **inputs,
+            **inputs,
             **{k: v for k, v in generation_params.items() if k in [
-                'max_length', 'min_length', 'do_sample', 'temperature',
-                'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
+                'max_length', 'min_length', 'do_sample', 'temperature',
+                'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
                 'repetition_penalty', 'early_stopping'
             ]}
         )
-        logger.info("Output generated successfully")
-
+
         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        logger.info(f"Final result: {result}")
-
-        return PredictionResponse(generated_text=result)
 
+        return PredictionResponse(
+            generated_text=result,
+            model_used=request.model
+        )
+
     except Exception as e:
-        logger.error(f"Error: {str(e)}")
-        logger.error(f"Error type: {type(e)}")
-        import traceback
-        logger.error(f"Traceback: {traceback.format_exc()}")
-        raise HTTPException(status_code=500, detail=str(e))
+        error_msg = f"Error during prediction: {str(e)}\n{traceback.format_exc()}"
+        logger.error(error_msg)
+        raise HTTPException(status_code=500, detail=error_msg)
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
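
A note on the parameter handling above: predict() resolves generation settings in three layers, server-side defaults first, then the model's saved generation_config, then any parameters sent with the request. A minimal standalone sketch of that merge, with made-up values for illustration:

# Sketch of the three-layer merge in predict(); all values here are illustrative.
defaults = {"max_length": 300, "num_beams": 2}      # stand-in for a DEFAULT_GENERATION_CONFIGS entry
model_cfg = {"num_beams": 4, "temperature": None}   # stand-in for model.generation_config.to_dict()
request_params = {"max_length": 200}                # stand-in for request.parameters

generation_params = defaults.copy()
generation_params.update({k: v for k, v in model_cfg.items() if v is not None})  # None values skipped
generation_params.update(request_params)            # request-level overrides win last

assert generation_params == {"max_length": 200, "num_beams": 4}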
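
To exercise the new endpoint once the server is running, a minimal client could look like the following; the base URL assumes the local `__main__` run on port 7860, and the dream text is made up:

import requests

BASE_URL = "http://localhost:7860"  # assumed local run; adjust for a deployed endpoint

payload = {
    "inputs": "I was flying over a city made of glass.",  # example dream text
    "model": "nidra-v2",
    "parameters": {"max_length": 200, "num_beams": 4},    # optional per-request overrides
}

resp = requests.post(f"{BASE_URL}/predict", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json())  # shape: {"generated_text": "...", "model_used": "nidra-v2"}

The first request for a given model pays the one-time load cost inside ModelManager; subsequent requests reuse the cached model and tokenizer.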