m1k3wn committed
Commit 2095fff (verified)
Parent: 330b757

Create app.py

Files changed (1):
  app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import logging
+from typing import Optional, Dict, Any
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(title="Dream Interpretation API")
+
+# Define the model names
+MODELS = {
+    "nidra-v1": "m1k3wn/nidra-v1",
+    "nidra-v2": "m1k3wn/nidra-v2"
+}
+
+# Cache for loaded models and tokenizers
+loaded_models = {}
+loaded_tokenizers = {}
+
+# Pydantic models for request/response validation
+class PredictionRequest(BaseModel):
+    inputs: str
+    model: str = "nidra-v1"  # Default to v1
+    parameters: Optional[Dict[str, Any]] = {}
+
+class PredictionResponse(BaseModel):
+    generated_text: str
+
+def load_model(model_name: str):
+    """Load model and tokenizer on demand, caching them for reuse"""
+    if model_name not in loaded_models:
+        logger.info(f"Loading {model_name}...")
+        try:
+            model_path = MODELS[model_name]
+            tokenizer = AutoTokenizer.from_pretrained(model_path)
+            model = AutoModelForSeq2SeqLM.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
+            loaded_models[model_name] = model
+            loaded_tokenizers[model_name] = tokenizer
+            logger.info(f"Successfully loaded {model_name}")
+        except Exception as e:
+            logger.error(f"Error loading {model_name}: {str(e)}")
+            raise
+    return loaded_tokenizers[model_name], loaded_models[model_name]
+
+@app.get("/")
+def read_root():
+    """Root endpoint with API info"""
+    return {
+        "api_name": "Dream Interpretation API",
+        "models_available": list(MODELS.keys()),
+        "endpoints": {
+            "/predict": "POST - Make predictions",
+            "/health": "GET - Health check"
+        }
+    }
+
+@app.get("/health")
+def health_check():
+    """Basic health check endpoint"""
+    return {"status": "healthy"}
+
+@app.post("/predict", response_model=PredictionResponse)
+async def predict(request: PredictionRequest):
+    """Make a prediction using the specified model"""
+    try:
+        if request.model not in MODELS:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Invalid model choice. Available models: {list(MODELS.keys())}"
+            )
+
+        # Load model on demand
+        tokenizer, model = load_model(request.model)
+
+        # Prepend the shared prefix
+        full_input = "Interpret this dream: " + request.inputs
+
+        # Tokenize, move inputs to the model's device, and generate
+        input_ids = tokenizer(full_input, return_tensors="pt").input_ids.to(model.device)
+        outputs = model.generate(input_ids, **(request.parameters or {}))
+        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        return PredictionResponse(generated_text=decoded)
+
+    except HTTPException:
+        # Re-raise intended HTTP errors (like the 400 above) unchanged
+        raise
+    except Exception as e:
+        logger.error(f"Error in prediction: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
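
For quick testing, a minimal client sketch is shown below. It assumes the app is served with something like "uvicorn app:app --host 0.0.0.0 --port 8000"; the host, port, prompt text, and generation parameters are illustrative assumptions, not part of this commit.

import requests

# Call the /predict endpoint defined in app.py above
response = requests.post(
    "http://localhost:8000/predict",  # assumed host/port, adjust to your deployment
    json={
        "inputs": "I was flying over a city made of glass.",
        "model": "nidra-v2",
        # This dict is unpacked verbatim into model.generate(), so any
        # standard generation kwarg (num_beams, do_sample, ...) passes through
        "parameters": {"max_new_tokens": 128},
    },
)
response.raise_for_status()
print(response.json()["generated_text"])

Because parameters is forwarded straight to model.generate(), an unrecognised generation kwarg will surface as a 500 from the generic exception handler rather than as a request validation error.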