# nidra/app.py
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
from typing import Optional, Dict, Any, ClassVar
import logging
import os
import sys
import traceback
from functools import lru_cache
import gc
import asyncio
from fastapi import BackgroundTasks
import psutil
# Initialize FastAPI
app = FastAPI()
# Set up logging with detailed formatting (logger is used throughout the module)
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
logger.warning("No HF_TOKEN found in environment variables")
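# Short model names mapped to their Hugging Face Hub repositories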
MODELS = {
"nidra-v1": "m1k3wn/nidra-v1",
"nidra-v2": "m1k3wn/nidra-v2"
}
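# Per-model default generation parameters; the model's own generation_config
# and any per-request parameters can override these at request time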
DEFAULT_GENERATION_CONFIGS = {
"nidra-v1": {
"max_length": 300,
"min_length": 150,
"num_beams": 8,
"temperature": 0.55,
"do_sample": True,
"top_p": 0.95,
"repetition_penalty": 4.5,
"no_repeat_ngram_size": 4,
"early_stopping": True,
"length_penalty": 1.2,
},
"nidra-v2": {
"max_length": 300,
"min_length": 150,
"num_beams": 8,
"temperature": 0.4,
"do_sample": True,
"top_p": 0.95,
"repetition_penalty": 3.5,
"no_repeat_ngram_size": 4,
"early_stopping": True,
"length_penalty": 1.2,
}
}
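# Lazily loads each model/tokenizer pair once and caches it for reuse across requests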
class ModelManager:
_instances: ClassVar[Dict[str, tuple]] = {}
    _lock = asyncio.Lock()  # Async lock: ensures each model is loaded only once under concurrent requests
@classmethod
async def get_model_and_tokenizer(cls, model_name: str):
async with cls._lock:
if model_name not in cls._instances:
try:
model_path = MODELS[model_name]
logger.debug(f"Attempting to load tokenizer from {model_path}")
try:
tokenizer = T5Tokenizer.from_pretrained(
model_path,
token=HF_TOKEN,
local_files_only=False
)
logger.debug("Tokenizer loaded successfully")
except Exception as e:
logger.error(f"Detailed tokenizer error: {str(e)}")
logger.error(f"HF_TOKEN present: {bool(HF_TOKEN)}")
raise
logger.debug("Attempting to load model")
model = T5ForConditionalGeneration.from_pretrained(
model_path,
token=HF_TOKEN,
local_files_only=False,
low_cpu_mem_usage=True,
torch_dtype=torch.float32
)
logger.debug("Model loaded successfully")
model.eval()
torch.set_num_threads(8)
cls._instances[model_name] = (model, tokenizer)
except Exception as e:
logger.error(f"Error loading {model_name}: {str(e)}")
raise
return cls._instances[model_name]
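# Request/response schemas for the /predict endpoint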
class PredictionRequest(BaseModel):
inputs: str
model: str = "nidra-v1"
parameters: Optional[Dict[str, Any]] = None
class PredictionResponse(BaseModel):
generated_text: str
selected_model: str # Changed from model_used to avoid namespace conflict
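# Debugging endpoint reporting process memory and CPU usage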
@app.get("/debug/memory")
async def memory_usage():
process = psutil.Process()
memory_info = process.memory_info()
return {
"memory_used_mb": memory_info.rss / 1024 / 1024,
"memory_percent": process.memory_percent(),
"cpu_percent": process.cpu_percent()
}
@app.get("/version")
async def version():
return {
"python_version": sys.version,
"models_available": list(MODELS.keys())
}
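# Health check: verifies the default model can be loaded and lists cached models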
@app.get("/health")
async def health():
try:
await ModelManager.get_model_and_tokenizer("nidra-v1")
return {
"status": "healthy",
"loaded_models": list(ModelManager._instances.keys())
}
except Exception as e:
logger.error(f"Health check failed: {str(e)}")
return {
"status": "unhealthy",
"error": str(e)
}
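# Main inference endpoint: builds generation parameters, runs the model,
# and schedules memory cleanup after the response is sent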
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
try:
if request.model not in MODELS:
raise HTTPException(
status_code=400,
detail=f"Invalid model. Available models: {list(MODELS.keys())}"
)
model, tokenizer = await ModelManager.get_model_and_tokenizer(request.model)
generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
try:
model_generation_config = model.generation_config
generation_params.update({
k: v for k, v in model_generation_config.to_dict().items()
if v is not None
})
except Exception as config_load_error:
logger.warning(f"Using default generation config: {config_load_error}")
if request.parameters:
generation_params.update(request.parameters)
logger.debug(f"Final generation parameters: {generation_params}")
full_input = "Interpret this dream: " + request.inputs
inputs = tokenizer(
full_input,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True,
return_attention_mask=True
)
        def generate():
            # model.generate is a blocking, CPU-bound call; running it in a
            # worker thread (via asyncio.to_thread) keeps the event loop
            # responsive so the timeout below can actually fire.
            with torch.inference_mode():
                return model.generate(
                    **inputs,
                    **{k: v for k, v in generation_params.items() if k in [
                        'max_length', 'min_length', 'do_sample', 'temperature',
                        'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
                        'repetition_penalty', 'length_penalty', 'early_stopping'
                    ]}
                )

        outputs = await asyncio.wait_for(asyncio.to_thread(generate), timeout=70.0)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
background_tasks.add_task(cleanup_memory)
return PredictionResponse(
generated_text=result,
selected_model=request.model
)
except Exception as e:
error_msg = f"Error during prediction: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
raise HTTPException(status_code=500, detail=error_msg)
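# Free Python and (if present) CUDA memory after each request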
def cleanup_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)