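"""FastAPI inference service for the m1k3wn/nidra-v1 and m1k3wn/nidra-v2 T5 models.

Exposes /version, /health, and /predict endpoints. /predict prefixes the input with
"Interpret this dream: " and returns the decoded generation as `generated_text`.
"""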
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
from typing import Optional, Dict, Any
import logging
import os
import sys
import traceback

# Initialize FastAPI first
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN")

MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2"
}

# Define default generation configurations for each model
DEFAULT_GENERATION_CONFIGS = {
    "nidra-v1": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.55,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 4.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    },
    "nidra-v2": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.4,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 3.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    }
}

class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"
    parameters: Optional[Dict[str, Any]] = None  # Allow custom parameters

class PredictionResponse(BaseModel):
    generated_text: str
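
# Example /predict request body (illustrative values only):
# {
#     "inputs": "I was flying over a city made of glass",
#     "model": "nidra-v2",
#     "parameters": {"temperature": 0.7, "max_length": 256}
# }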

@app.get("/version")
async def version():
    return {"python_version": sys.version}

@app.get("/health")
async def health():
    return {"status": "healthy"}

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        # Validate model
        if request.model not in MODELS:
            raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")
            
        logger.info(f"Loading model: {request.model}")
        model_path = MODELS[request.model]
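        # Note: the tokenizer and model are re-loaded from the Hub on every request.
        # The local Transformers cache avoids repeated downloads, but re-loading the
        # weights still adds latency; caching the loaded objects at module level
        # would avoid this.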
        
        # Add debug logging
        logger.info("Attempting to load tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained(
            model_path,
            token=HF_TOKEN,
            local_files_only=False
        )
        logger.info("Tokenizer loaded successfully")
        
        logger.info("Attempting to load model...")
        model = T5ForConditionalGeneration.from_pretrained(
            model_path,
            token=HF_TOKEN,
            local_files_only=False
        )
        logger.info("Model loaded successfully")

        # Priority: 1. Request parameters, 2. Model's saved generation_config, 3. Default configs
        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()

        # Try to load model's saved generation config
        try:
            model_generation_config = GenerationConfig.from_pretrained(model_path)
            # Convert to dict to merge with default configs
            generation_params.update({
                k: v for k, v in model_generation_config.to_dict().items() 
                if v is not None
            })
        except Exception as config_load_error:
            logger.warning(f"Could not load model's generation config: {config_load_error}")

        # Override with request-specific parameters if provided
        if request.parameters:
            generation_params.update(request.parameters)

        logger.info(f"Final Generation Parameters: {generation_params}")

        
        full_input = "Interpret this dream: " + request.inputs
        logger.info(f"Processing input: {full_input}")
        
        logger.info("Tokenizing input...")
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )
        logger.info("Input tokenized successfully")
        
        logger.info("Generating output...")
        
        # Generate with final parameters
        outputs = model.generate(
            **inputs, 
            **{k: v for k, v in generation_params.items() if k in [
                'max_length', 'min_length', 'do_sample', 'temperature', 
                'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size', 
                'repetition_penalty', 'early_stopping', 'length_penalty'
            ]}
        )
        logger.info("Output generated successfully")
        
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Final result: {result}")
        
        return PredictionResponse(generated_text=result)
        
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        logger.error(f"Error type: {type(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))
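
# To run locally (a minimal sketch; assumes this file is saved as main.py and that
# HF_TOKEN grants access to the m1k3wn/nidra-* repositories):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example request:
#   curl -X POST http://localhost:8000/predict \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": "I dreamt I was lost in a forest", "model": "nidra-v1"}'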