# nidra/app.py
import logging
import os
import sys
import traceback
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import GenerationConfig, T5ForConditionalGeneration, T5Tokenizer

# Initialize FastAPI first
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN")

# Short model names mapped to their Hugging Face Hub repositories
MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2",
}

# Define default generation configurations for each model
DEFAULT_GENERATION_CONFIGS = {
    "nidra-v1": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.55,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 4.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    },
    "nidra-v2": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.4,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 3.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    },
}
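
# Both presets pair beam search (num_beams=8) with sampling (do_sample=True),
# i.e. beam-sample decoding; v2 samples at a lower temperature and with a
# softer repetition penalty than v1, so its outputs should vary less run to run.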


class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"
    parameters: Optional[Dict[str, Any]] = None  # Allow custom parameters


class PredictionResponse(BaseModel):
    generated_text: str
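

# A hypothetical /predict request body (the dream text and parameter overrides
# are invented for illustration, not taken from the original project):
# {
#     "inputs": "I was walking through a forest of glass trees",
#     "model": "nidra-v1",
#     "parameters": {"max_length": 200, "temperature": 0.7}
# }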


@app.get("/version")
async def version():
    return {"python_version": sys.version}


@app.get("/health")
async def health():
    return {"status": "healthy"}


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        # Validate model
        if request.model not in MODELS:
            raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")

        logger.info(f"Loading model: {request.model}")
        model_path = MODELS[request.model]

        # Note: the tokenizer and model are loaded from the Hub on every request;
        # caching them at module level would avoid the repeated cost.
        logger.info("Attempting to load tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained(
            model_path,
            token=HF_TOKEN,
            local_files_only=False,
        )
        logger.info("Tokenizer loaded successfully")

        logger.info("Attempting to load model...")
        model = T5ForConditionalGeneration.from_pretrained(
            model_path,
            token=HF_TOKEN,
            local_files_only=False,
        )
        logger.info("Model loaded successfully")

        # Priority: 1. Request parameters, 2. Model's saved generation_config, 3. Default configs
        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()

        # Try to load the model's saved generation config and merge its
        # non-null values over the defaults
        try:
            model_generation_config = GenerationConfig.from_pretrained(model_path)
            generation_params.update({
                k: v for k, v in model_generation_config.to_dict().items()
                if v is not None
            })
        except Exception as config_load_error:
            logger.warning(f"Could not load model's generation config: {config_load_error}")

        # Override with request-specific parameters if provided
        if request.parameters:
            generation_params.update(request.parameters)

        logger.info(f"Final generation parameters: {generation_params}")

        # Prepend the task prefix the models expect
        full_input = "Interpret this dream: " + request.inputs
        logger.info(f"Processing input: {full_input}")

        logger.info("Tokenizing input...")
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        logger.info("Input tokenized successfully")
logger.info("Generating output...")
# Generate with final parameters
outputs = model.generate(
**inputs,
**{k: v for k, v in generation_params.items() if k in [
'max_length', 'min_length', 'do_sample', 'temperature',
'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
'repetition_penalty', 'early_stopping'
]}
)
logger.info("Output generated successfully")
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
logger.info(f"Final result: {result}")
return PredictionResponse(generated_text=result)
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 400 for an unknown model)
        # rather than letting the handler below convert them into a 500
        raise
    except Exception as e:
        logger.error(f"Error: {e}")
        logger.error(f"Error type: {type(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))
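

# Minimal local entrypoint sketch. Assumes uvicorn is installed; the host and
# port are illustrative choices, not taken from the original deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)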