m1k3wn committed
Commit 4347c84 · verified · 1 Parent(s): 3e742c6

Update app.py

adds type safety and improved debugging

Files changed (1)
  1. app.py +35 -33
app.py CHANGED
@@ -1,5 +1,5 @@
 from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import logging
 from typing import Optional, Dict, Any
@@ -30,9 +30,23 @@ loaded_tokenizers = {}
 # Pydantic models for request/response validation
 class PredictionRequest(BaseModel):
     inputs: str
-    model: str = "nidra-v1"  # Default to v1
+    model: str = "nidra-v1"
     parameters: Optional[Dict[str, Any]] = {}
 
+    @validator('inputs')
+    def validate_inputs(cls, v):
+        if not isinstance(v, str):
+            raise ValueError('inputs must be a string')
+        if not v.strip():
+            raise ValueError('inputs cannot be empty')
+        return v.strip()
+
+    @validator('model')
+    def validate_model(cls, v):
+        if v not in MODELS:
+            raise ValueError(f'model must be one of: {", ".join(MODELS.keys())}')
+        return v
+
 class PredictionResponse(BaseModel):
     generated_text: str
@@ -43,21 +57,20 @@ def load_model(model_name: str):
     try:
         model_path = MODELS[model_name]
 
-        # Load tokenizer with minimal settings
+        logger.info("Loading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained(
             model_path,
             token=HF_TOKEN,
-            use_fast=False  # Use slower but more stable tokenizer
+            use_fast=False
         )
 
-        # Load model with minimal settings
+        logger.info("Loading model...")
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_path,
             token=HF_TOKEN,
-            torch_dtype=torch.float32,  # Use standard precision
+            torch_dtype=torch.float32,
         )
 
-        # Move model to CPU explicitly
         model = model.cpu()
 
         loaded_models[model_name] = model
@@ -68,49 +81,34 @@ def load_model(model_name: str):
         raise
     return loaded_tokenizers[model_name], loaded_models[model_name]
 
-@app.get("/")
-def read_root():
-    """Root endpoint with API info"""
-    return {
-        "api_name": "Dream Interpretation API",
-        "models_available": list(MODELS.keys()),
-        "endpoints": {
-            "/predict": "POST - Make predictions",
-            "/health": "GET - Health check"
-        }
-    }
-
-@app.get("/health")
-def health_check():
-    """Basic health check endpoint"""
-    return {"status": "healthy"}
-
 @app.post("/predict", response_model=PredictionResponse)
 async def predict(request: PredictionRequest):
     """Make a prediction using the specified model"""
     try:
-        if request.model not in MODELS:
-            raise HTTPException(
-                status_code=400,
-                detail=f"Invalid model choice. Available models: {list(MODELS.keys())}"
-            )
-
         # Load model on demand
         tokenizer, model = load_model(request.model)
 
+        # Log the input for debugging
+        logger.info(f"Processing input: {request.inputs}")
+
         # Prepend the shared prefix
         full_input = "Interpret this dream: " + request.inputs
+        logger.info(f"Full input: {full_input}")
 
-        # Tokenize and generate with explicit error handling
         try:
-            input_ids = tokenizer(
+            # Tokenize
+            tokenizer_output = tokenizer(
                 full_input,
                 return_tensors="pt",
                 padding=True,
                 truncation=True,
                 max_length=512
-            ).input_ids
+            )
+            logger.info("Tokenization successful")
 
+            input_ids = tokenizer_output.input_ids
+
+            # Generate
             outputs = model.generate(
                 input_ids,
                 max_length=200,
@@ -118,8 +116,12 @@ async def predict(request: PredictionRequest):
                 no_repeat_ngram_size=2,
                 **request.parameters
             )
+            logger.info("Generation successful")
 
+            # Decode
             decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            logger.info(f"Decoded output: {decoded}")
+
         except Exception as e:
             logger.error(f"Error in model prediction pipeline: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Model prediction failed: {str(e)}")
 
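A quick client-side check of the new validation, sketched below with the requests library. The base URL and port are assumptions (Spaces commonly expose 7860) and the dream text is made up; with the validators in place, a blank inputs field or an unknown model name should now be rejected by FastAPI with a 422 validation error before predict() runs, rather than reaching the old in-endpoint 400 check.

# Hypothetical client-side check of the new request validation.
# Assumptions: the app is reachable locally on port 7860 and the
# nidra-v1 weights load successfully; adjust BASE as needed.
import requests

BASE = "http://localhost:7860"

# Well-formed request: "model" falls back to the "nidra-v1" default if omitted
ok = requests.post(
    f"{BASE}/predict",
    json={"inputs": "I was flying over a city made of glass.", "model": "nidra-v1"},
)
print(ok.status_code, ok.json())  # expect 200 and {"generated_text": "..."}

# Whitespace-only input now fails validate_inputs, so FastAPI rejects it
bad_input = requests.post(f"{BASE}/predict", json={"inputs": "   "})
print(bad_input.status_code)  # expect 422 (validation error), not a model error

# Unknown model names fail validate_model the same way
bad_model = requests.post(
    f"{BASE}/predict",
    json={"inputs": "a recurring dream about trains", "model": "not-a-model"},
)
print(bad_model.status_code)  # expect 422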
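
Note that @validator is the Pydantic v1 decorator; on Pydantic v2 it still imports but is deprecated in favour of field_validator. If the Space ever moves to v2, a minimal sketch of the equivalent spelling (MODELS here is a placeholder for the real mapping defined earlier in app.py):

# Sketch only: Pydantic v2 spelling of the same validators.
# MODELS is a stand-in for the real mapping defined in app.py.
from typing import Any, Dict, Optional
from pydantic import BaseModel, field_validator

MODELS = {"nidra-v1": "m1k3wn/nidra-v1"}  # placeholder path for illustration

class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"
    parameters: Optional[Dict[str, Any]] = {}

    @field_validator("inputs")
    @classmethod
    def validate_inputs(cls, v: str) -> str:
        # Pydantic has already enforced that v is a str by this point
        if not v.strip():
            raise ValueError("inputs cannot be empty")
        return v.strip()

    @field_validator("model")
    @classmethod
    def validate_model(cls, v: str) -> str:
        if v not in MODELS:
            raise ValueError(f"model must be one of: {', '.join(MODELS.keys())}")
        return v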
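
For debugging the generation path outside the Space, the tokenize, generate and decode steps the endpoint performs can be reproduced standalone. The sketch below substitutes the public t5-small checkpoint for the nidra models (which app.py loads with an HF_TOKEN) and passes the attention mask explicitly, which transformers may otherwise warn about when padding is enabled:

# Standalone sketch of the endpoint's tokenize -> generate -> decode pipeline.
# Assumption: t5-small stands in for the private nidra checkpoints;
# requires transformers, torch and sentencepiece.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").cpu()

full_input = "Interpret this dream: I was flying over a city made of glass."
encoded = tokenizer(
    full_input,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512,
)

outputs = model.generate(
    encoded.input_ids,
    attention_mask=encoded.attention_mask,  # explicit mask avoids a padding warning
    max_length=200,
    no_repeat_ngram_size=2,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))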