m1k3wn committed
Commit 78a09b4 · verified · 1 Parent(s): 4347c84

Update app.py

simplifies for debugging

Files changed (1)
  1. app.py +25 -101
app.py CHANGED
@@ -1,8 +1,7 @@
 from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel, validator
+from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import logging
-from typing import Optional, Dict, Any
 import os
 import torch
 
@@ -10,124 +9,49 @@ import torch
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-app = FastAPI(title="Dream Interpretation API")
+# Initialize FastAPI
+app = FastAPI()
 
-# Get HF token from environment variable
+# Get HF token
 HF_TOKEN = os.environ.get("HF_TOKEN")
-if not HF_TOKEN:
-    raise ValueError("HF_TOKEN environment variable must be set")
 
-# Define the model names
+# Define models
 MODELS = {
     "nidra-v1": "m1k3wn/nidra-v1",
     "nidra-v2": "m1k3wn/nidra-v2"
 }
 
-# Cache for loaded models
-loaded_models = {}
-loaded_tokenizers = {}
-
-# Pydantic models for request/response validation
+# Simple request model
 class PredictionRequest(BaseModel):
     inputs: str
     model: str = "nidra-v1"
-    parameters: Optional[Dict[str, Any]] = {}
-
-    @validator('inputs')
-    def validate_inputs(cls, v):
-        if not isinstance(v, str):
-            raise ValueError('inputs must be a string')
-        if not v.strip():
-            raise ValueError('inputs cannot be empty')
-        return v.strip()
-
-    @validator('model')
-    def validate_model(cls, v):
-        if v not in MODELS:
-            raise ValueError(f'model must be one of: {", ".join(MODELS.keys())}')
-        return v
 
+# Simple response model
 class PredictionResponse(BaseModel):
     generated_text: str
 
-def load_model(model_name: str):
-    """Load model and tokenizer on demand"""
-    if model_name not in loaded_models:
-        logger.info(f"Loading {model_name}...")
-        try:
-            model_path = MODELS[model_name]
-
-            logger.info("Loading tokenizer...")
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_path,
-                token=HF_TOKEN,
-                use_fast=False
-            )
-
-            logger.info("Loading model...")
-            model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_path,
-                token=HF_TOKEN,
-                torch_dtype=torch.float32,
-            )
-
-            model = model.cpu()
-
-            loaded_models[model_name] = model
-            loaded_tokenizers[model_name] = tokenizer
-            logger.info(f"Successfully loaded {model_name}")
-        except Exception as e:
-            logger.error(f"Error loading {model_name}: {str(e)}")
-            raise
-    return loaded_tokenizers[model_name], loaded_models[model_name]
+@app.get("/")
+async def root():
+    return {"message": "Dream Interpretation API", "status": "running"}
+
+@app.get("/health")
+async def health():
+    return {"status": "healthy"}
 
 @app.post("/predict", response_model=PredictionResponse)
 async def predict(request: PredictionRequest):
-    """Make a prediction using the specified model"""
     try:
-        # Load model on demand
-        tokenizer, model = load_model(request.model)
-
-        # Log the input for debugging
-        logger.info(f"Processing input: {request.inputs}")
-
-        # Prepend the shared prefix
+        # Load model
+        model_path = MODELS[request.model]
+        tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, token=HF_TOKEN)
+
+        # Process input
         full_input = "Interpret this dream: " + request.inputs
-        logger.info(f"Full input: {full_input}")
-
-        try:
-            # Tokenize
-            tokenizer_output = tokenizer(
-                full_input,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=512
-            )
-            logger.info("Tokenization successful")
-
-            input_ids = tokenizer_output.input_ids
-
-            # Generate
-            outputs = model.generate(
-                input_ids,
-                max_length=200,
-                num_return_sequences=1,
-                no_repeat_ngram_size=2,
-                **request.parameters
-            )
-            logger.info("Generation successful")
-
-            # Decode
-            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-            logger.info(f"Decoded output: {decoded}")
-
-        except Exception as e:
-            logger.error(f"Error in model prediction pipeline: {str(e)}")
-            raise HTTPException(status_code=500, detail=f"Model prediction failed: {str(e)}")
-
-        return PredictionResponse(generated_text=decoded)
-
+        inputs = tokenizer(full_input, return_tensors="pt")
+        outputs = model.generate(**inputs)
+        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        return PredictionResponse(generated_text=result)
     except Exception as e:
-        logger.error(f"Error in prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
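
For anyone exercising the simplified Space, a minimal smoke test of the three endpoints might look like the sketch below. This is not part of the commit; the base URL is a placeholder (7860 is the usual Spaces port) and should point at wherever the app is actually served.

import requests

BASE_URL = "http://localhost:7860"  # placeholder; point at the running Space

# Liveness checks against the new root and /health routes
print(requests.get(f"{BASE_URL}/").json())        # {"message": "Dream Interpretation API", "status": "running"}
print(requests.get(f"{BASE_URL}/health").json())  # {"status": "healthy"}

# One prediction; the request body now carries only "inputs" and "model",
# since the "parameters" field was removed in this commit
resp = requests.post(
    f"{BASE_URL}/predict",
    json={"inputs": "I was flying over a city", "model": "nidra-v1"},
)
resp.raise_for_status()
print(resp.json()["generated_text"])

Note that the simplified handler reloads the tokenizer and model from the Hub on every request, since the load_model cache was removed along with the validators; that is presumably acceptable while debugging. If it becomes the bottleneck later, a per-process cache along the lines of the removed helper could be restored. The following sketch assumes the module-level MODELS and HF_TOKEN from app.py:

# Sketch: load each model once per process (mirrors the removed load_model helper)
_loaded = {}

def get_model(name: str):
    if name not in _loaded:
        path = MODELS[name]
        tokenizer = AutoTokenizer.from_pretrained(path, token=HF_TOKEN)
        model = AutoModelForSeq2SeqLM.from_pretrained(path, token=HF_TOKEN)
        _loaded[name] = (tokenizer, model)
    return _loaded[name]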