from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import logging
from typing import Optional, Dict, Any
import os
import torch

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Dream Interpretation API")

# Get HF token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable must be set")

# Define the model names
MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2"
}

# Cache for loaded models
loaded_models = {}
loaded_tokenizers = {}

# Pydantic models for request/response validation
class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"  # Default to v1
    parameters: Optional[Dict[str, Any]] = {}

class PredictionResponse(BaseModel):
    generated_text: str
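
# Example request body handled by /predict. Keys in "parameters" are passed
# straight through to model.generate, so any valid generation argument should
# work; the values below are illustrative, not server defaults:
#
#   {
#     "inputs": "I was walking through a forest at night",
#     "model": "nidra-v2",
#     "parameters": {"max_length": 150, "num_beams": 4}
#   }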

def load_model(model_name: str):
    """Load model and tokenizer on demand"""
    if model_name not in loaded_models:
        logger.info(f"Loading {model_name}...")
        try:
            model_path = MODELS[model_name]
            
            # Load tokenizer with minimal settings
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, 
                token=HF_TOKEN,
                use_fast=False  # Use slower but more stable tokenizer
            )
            
            # Load model with minimal settings
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_path,
                token=HF_TOKEN,
                torch_dtype=torch.float32,  # Use standard precision
            )
            
            # Move model to CPU explicitly
            model = model.cpu()
            
            loaded_models[model_name] = model
            loaded_tokenizers[model_name] = tokenizer
            logger.info(f"Successfully loaded {model_name}")
        except Exception as e:
            logger.error(f"Error loading {model_name}: {str(e)}")
            raise
    return loaded_tokenizers[model_name], loaded_models[model_name]

@app.get("/")
def read_root():
    """Root endpoint with API info"""
    return {
        "api_name": "Dream Interpretation API",
        "models_available": list(MODELS.keys()),
        "endpoints": {
            "/predict": "POST - Make predictions",
            "/health": "GET - Health check"
        }
    }

@app.get("/health")
def health_check():
    """Basic health check endpoint"""
    return {"status": "healthy"}

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Make a prediction using the specified model"""
    try:
        if request.model not in MODELS:
            raise HTTPException(
                status_code=400, 
                detail=f"Invalid model choice. Available models: {list(MODELS.keys())}"
            )

        # Load model on demand
        tokenizer, model = load_model(request.model)

        # Prepend the shared prefix
        full_input = "Interpret this dream: " + request.inputs

        # Tokenize and generate with explicit error handling
        try:
            input_ids = tokenizer(
                full_input, 
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).input_ids
            
            # Merge defaults with caller-supplied generation parameters so an
            # override such as "max_length" doesn't raise a duplicate-keyword error
            gen_kwargs = {
                "max_length": 200,
                "num_return_sequences": 1,
                "no_repeat_ngram_size": 2,
                **(request.parameters or {}),
            }
            outputs = model.generate(input_ids, **gen_kwargs)
            
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            logger.error(f"Error in model prediction pipeline: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model prediction failed: {str(e)}")

        return PredictionResponse(generated_text=decoded)

    except HTTPException:
        # Let HTTPExceptions (e.g. the 400 for an unknown model) pass through unchanged
        raise
    except Exception as e:
        logger.error(f"Error in prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
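
# Running locally (assuming this file is saved as app.py and uvicorn is
# installed; adjust the module name, host, and port to your setup):
#
#   HF_TOKEN=<your-token> uvicorn app:app --host 0.0.0.0 --port 8000
#
# Example request:
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "I was flying over the ocean", "model": "nidra-v1"}'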