Create app.py
app.py
ADDED
@@ -0,0 +1,94 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import logging
from typing import Optional, Dict, Any

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Dream Interpretation API")

# Define the model names
MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2"
}

# Cache for loaded models
loaded_models = {}
loaded_tokenizers = {}

# Pydantic models for request/response validation
class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"  # Default to v1
    parameters: Optional[Dict[str, Any]] = {}

class PredictionResponse(BaseModel):
    generated_text: str

def load_model(model_name: str):
    """Load model and tokenizer on demand"""
    if model_name not in loaded_models:
        logger.info(f"Loading {model_name}...")
        try:
            model_path = MODELS[model_name]
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype="auto"
            )
            loaded_models[model_name] = model
            loaded_tokenizers[model_name] = tokenizer
            logger.info(f"Successfully loaded {model_name}")
        except Exception as e:
            logger.error(f"Error loading {model_name}: {str(e)}")
            raise
    return loaded_tokenizers[model_name], loaded_models[model_name]

@app.get("/")
def read_root():
    """Root endpoint with API info"""
    return {
        "api_name": "Dream Interpretation API",
        "models_available": list(MODELS.keys()),
        "endpoints": {
            "/predict": "POST - Make predictions",
            "/health": "GET - Health check"
        }
    }

@app.get("/health")
def health_check():
    """Basic health check endpoint"""
    return {"status": "healthy"}

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Make a prediction using the specified model"""
    try:
        if request.model not in MODELS:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model choice. Available models: {list(MODELS.keys())}"
            )

        # Load model on demand
        tokenizer, model = load_model(request.model)

        # Prepend the shared prefix
        full_input = "Interpret this dream: " + request.inputs

        # Tokenize and generate
        input_ids = tokenizer(full_input, return_tensors="pt").input_ids
        outputs = model.generate(input_ids, **request.parameters)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return PredictionResponse(generated_text=decoded)

    except Exception as e:
        logger.error(f"Error in prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
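
For reference, a minimal client sketch for the /predict endpoint added above. It is not part of the commit: it assumes the app is served locally (for example with uvicorn app:app --port 8000), and the base URL, the sample dream text, and the max_length generation parameter are illustrative assumptions only.

# Hypothetical client for the Dream Interpretation API defined in app.py.
# BASE_URL is an assumption; point it at the deployed Space instead if needed.
import requests

BASE_URL = "http://localhost:8000"  # assumed local uvicorn address

payload = {
    "inputs": "I was flying over a city made of glass.",  # example dream text
    "model": "nidra-v1",                 # must be one of MODELS.keys()
    "parameters": {"max_length": 128},   # forwarded to model.generate(); example value
}

response = requests.post(f"{BASE_URL}/predict", json=payload, timeout=300)
response.raise_for_status()
print(response.json()["generated_text"])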