import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration
from typing import Optional, Dict, Any, ClassVar
import logging
import os
import sys
import traceback
import gc
import asyncio
import psutil

# Initialize FastAPI
app = FastAPI()

# Debugging logs
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Hugging Face token used to authenticate model downloads (needed if the repos are gated/private)
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    logger.warning("No HF_TOKEN found in environment variables")

MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2"
}

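# Per-model generation defaults. At request time these are overridden first by
# the model's own generation_config (when one is bundled with the checkpoint)
# and then by any "parameters" supplied in the request body.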
DEFAULT_GENERATION_CONFIGS = {
    "nidra-v1": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.55,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 4.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    },
    "nidra-v2": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.4,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 3.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    }
}

class ModelManager:
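    # Process-level cache: each model/tokenizer pair is loaded once on first
    # use and reused for every subsequent request. Loading happens inline, so
    # the first request for each model blocks while the weights are fetched.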
    _instances: ClassVar[Dict[str, tuple]] = {}

    @classmethod
    async def get_model_and_tokenizer(cls, model_name: str):
        if model_name not in cls._instances:
            try:
                model_path = MODELS[model_name]
                logger.debug(f"Loading tokenizer and model from {model_path}")
                
                # T5Tokenizer is the slow SentencePiece tokenizer; the use_fast
                # flag is only honoured by AutoTokenizer, so it is omitted here.
                tokenizer = T5Tokenizer.from_pretrained(
                    model_path,
                    token=HF_TOKEN
                )
                
                model = T5ForConditionalGeneration.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                    device_map='auto'
                )
                
                model.eval()
                torch.set_num_threads(6)  # Cap intra-op threads to the CPUs allocated to this instance
                cls._instances[model_name] = (model, tokenizer)
                
            except Exception as e:
                logger.error(f"Error loading {model_name}: {str(e)}")
                raise
                
        return cls._instances[model_name]

class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"
    parameters: Optional[Dict[str, Any]] = None

class PredictionResponse(BaseModel):
    generated_text: str
    selected_model: str  # Changed from model_used to avoid namespace conflict
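
# Example /predict request body (values are illustrative, not from the source):
#   {"inputs": "I was flying over my childhood home",
#    "model": "nidra-v2",
#    "parameters": {"max_length": 200}}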

# Memory debug endpoint
@app.get("/debug/memory")
async def memory_usage():
    process = psutil.Process()
    memory_info = process.memory_info()
    return {
        "memory_used_mb": memory_info.rss / 1024 / 1024,
        "memory_percent": process.memory_percent(),
        "cpu_percent": process.cpu_percent()
    }

# Version check
@app.get("/version")
async def version():
    return {
        "python_version": sys.version,
        "models_available": list(MODELS.keys())
    }

# Healthcheck endpoint
@app.get("/health")
async def health():
    try:
        logger.debug("Health check started")
        logger.debug(f"HF_TOKEN present: {bool(HF_TOKEN)}")
        logger.debug(f"Available models: {MODELS}")
        
        await ModelManager.get_model_and_tokenizer("nidra-v1")
        logger.debug("Model and tokenizer loaded successfully")
        
        return {
            "status": "healthy",
            "loaded_models": list(ModelManager._instances.keys())
        }
    except Exception as e:
        error_msg = f"Health check failed: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return {
            "status": "unhealthy",
            "error": str(e)
        }

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
    try:
        if request.model not in MODELS:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Available models: {list(MODELS.keys())}"
            )

        model, tokenizer = await ModelManager.get_model_and_tokenizer(request.model)
        
        # Free any cached memory immediately before generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()

        try:
            model_generation_config = model.generation_config
            generation_params.update({
                k: v for k, v in model_generation_config.to_dict().items()
                if v is not None
            })
        except Exception as config_load_error:
            logger.warning(f"Using default generation config: {config_load_error}")

        if request.parameters:
            generation_params.update(request.parameters)

        logger.debug(f"Final generation parameters: {generation_params}")

        full_input = "Interpret this dream: " + request.inputs
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
            return_attention_mask=True
        )
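        # Inputs are truncated to 512 tokens, which bounds encoder memory use
        # for very long dream descriptions.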

        def run_generation():
            # model.generate is CPU-bound and blocking, so it runs in a worker
            # thread via asyncio.to_thread (Python 3.9+); inference_mode is
            # entered inside that thread so it applies to the generation call.
            try:
                with torch.inference_mode():
                    return model.generate(
                        **inputs,
                        **{k: v for k, v in generation_params.items() if k in [
                            'max_length', 'min_length', 'do_sample', 'temperature',
                            'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
                            'repetition_penalty', 'early_stopping'
                        ]}
                    )
            finally:
                # Ensure cleanup happens even if generation fails
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        # Run generation off the event loop so the server stays responsive and
        # the timeout can return a 504 even if generation is still running.
        outputs = await asyncio.wait_for(
            asyncio.to_thread(run_generation), timeout=45.0
        )
            
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        background_tasks.add_task(cleanup_memory)
        
        return PredictionResponse(
            generated_text=result,
            selected_model=request.model
        )

    except asyncio.TimeoutError:
        logger.error("Generation timed out")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise HTTPException(status_code=504, detail="Generation timed out")
    except Exception as e:
        error_msg = f"Error during prediction: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise HTTPException(status_code=500, detail=error_msg)

def cleanup_memory():
    try:
        # gc.collect() with no arguments already runs a full (generation 2) collection
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        logger.error(f"Error in cleanup: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
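
# Quick local test (illustrative; assumes the server is running on port 7860):
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": "I dreamt I was lost in a forest", "model": "nidra-v1"}'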