from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
from typing import Optional, Dict, Any
import logging
import os
import sys
import traceback

# Initialize FastAPI first
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN")

MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2"
}

# Define default generation configurations for each model
DEFAULT_GENERATION_CONFIGS = {
    "nidra-v1": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.55,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 4.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    },
    "nidra-v2": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.4,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 3.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    }
}

class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"
    parameters: Optional[Dict[str, Any]] = None  # Allow custom generation parameters

class PredictionResponse(BaseModel):
    generated_text: str

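# A request to the predict endpoint is expected to look roughly like this (a sketch based
# on the fields above; the example values and the custom "parameters" keys are illustrative
# and are merged into the generation settings further below):
#   {
#     "inputs": "I dreamt I was lost in a forest",
#     "model": "nidra-v2",
#     "parameters": {"max_length": 200, "temperature": 0.7}
#   }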
@app.get("/version")  # route path assumed
async def version():
    return {"python_version": sys.version}

@app.get("/health")  # route path assumed
async def health():
    return {"status": "healthy"}

@app.post("/predict", response_model=PredictionResponse)  # route path assumed
async def predict(request: PredictionRequest):
    try:
        # Validate model
        if request.model not in MODELS:
            raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")

        logger.info(f"Loading model: {request.model}")
        model_path = MODELS[request.model]

        # Add debug logging
        logger.info("Attempting to load tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained(
            model_path,
            token=HF_TOKEN,
            local_files_only=False,
            return_special_tokens_mask=True
        )
        logger.info("Tokenizer loaded successfully")

        logger.info("Attempting to load model...")
        model = T5ForConditionalGeneration.from_pretrained(
            model_path,
            token=HF_TOKEN,
            local_files_only=False
        )
        logger.info("Model loaded successfully")

        # Priority: 1. Request parameters, 2. Model's saved generation_config, 3. Default configs
        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
        # Try to load the model's saved generation config (pass the token so gated/private repos work)
        try:
            model_generation_config = GenerationConfig.from_pretrained(model_path, token=HF_TOKEN)
            # Convert to dict to merge with default configs
            generation_params.update({
                k: v for k, v in model_generation_config.to_dict().items()
                if v is not None
            })
        except Exception as config_load_error:
            logger.warning(f"Could not load model's generation config: {config_load_error}")
        # Override with request-specific parameters if provided
        if request.parameters:
            generation_params.update(request.parameters)

        logger.info(f"Final generation parameters: {generation_params}")

        full_input = "Interpret this dream: " + request.inputs
        logger.info(f"Processing input: {full_input}")

        logger.info("Tokenizing input...")
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )
        logger.info("Input tokenized successfully")

        logger.info("Generating output...")
        # Generate with the final parameters, passing only keys that generate() accepts
        outputs = model.generate(
            **inputs,
            **{k: v for k, v in generation_params.items() if k in [
                'max_length', 'min_length', 'do_sample', 'temperature',
                'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
                'repetition_penalty', 'length_penalty', 'early_stopping'
            ]}
        )
logger.info("Output generated successfully") | |
result = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
logger.info(f"Final result: {result}") | |
return PredictionResponse(generated_text=result) | |
except Exception as e: | |
logger.error(f"Error: {str(e)}") | |
logger.error(f"Error type: {type(e)}") | |
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))
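
# A minimal launch sketch, assuming this Space serves the app with uvicorn on port 7860
# (the usual Hugging Face Spaces port); host, port, and the /predict route are assumptions.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example call against the assumed /predict route:
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": "I was flying over the ocean", "model": "nidra-v1"}'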