from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
from typing import Optional, Dict, Any
import logging
import os
import sys
import traceback
# Initialize FastAPI first
app = FastAPI()
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN")
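# The token is forwarded to from_pretrained below; it is only needed if the model repos are private.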
MODELS = {
"nidra-v1": "m1k3wn/nidra-v1",
"nidra-v2": "m1k3wn/nidra-v2"
}
# Define default generation configurations for each model
DEFAULT_GENERATION_CONFIGS = {
"nidra-v1": {
"max_length": 300,
"min_length": 150,
"num_beams": 8,
"temperature": 0.55,
"do_sample": True,
"top_p": 0.95,
"repetition_penalty": 4.5,
"no_repeat_ngram_size": 4,
"early_stopping": True,
"length_penalty": 1.2,
},
"nidra-v2": {
"max_length": 300,
"min_length": 150,
"num_beams": 8,
"temperature": 0.4,
"do_sample": True,
"top_p": 0.95,
"repetition_penalty": 3.5,
"no_repeat_ngram_size": 4,
"early_stopping": True,
"length_penalty": 1.2,
}
}
class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"
    parameters: Optional[Dict[str, Any]] = None  # Allow custom parameters

class PredictionResponse(BaseModel):
    generated_text: str

@app.get("/version")
async def version():
return {"python_version": sys.version}
@app.get("/health")
async def health():
return {"status": "healthy"}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
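    # Note: the tokenizer and model are loaded inside the request handler, so they are
    # re-initialized on every /predict call rather than being cached at module level.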
    try:
        # Validate model
        if request.model not in MODELS:
            raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")

        logger.info(f"Loading model: {request.model}")
        model_path = MODELS[request.model]

        # Add debug logging
        logger.info("Attempting to load tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained(
            model_path,
            token=HF_TOKEN,
            local_files_only=False,
            return_special_tokens_mask=True
        )
        logger.info("Tokenizer loaded successfully")

        logger.info("Attempting to load model...")
        model = T5ForConditionalGeneration.from_pretrained(
            model_path,
            token=HF_TOKEN,
            local_files_only=False
        )
        logger.info("Model loaded successfully")

        # Priority: 1. Request parameters, 2. Model's saved generation_config, 3. Default configs
        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()

        # Try to load model's saved generation config
        try:
            model_generation_config = GenerationConfig.from_pretrained(model_path)
            # Convert to dict to merge with default configs
            generation_params.update({
                k: v for k, v in model_generation_config.to_dict().items()
                if v is not None
            })
        except Exception as config_load_error:
            logger.warning(f"Could not load model's generation config: {config_load_error}")

        # Override with request-specific parameters if provided
        if request.parameters:
            generation_params.update(request.parameters)

        logger.info(f"Final Generation Parameters: {generation_params}")

        full_input = "Interpret this dream: " + request.inputs
        logger.info(f"Processing input: {full_input}")

        logger.info("Tokenizing input...")
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )
        logger.info("Input tokenized successfully")

        logger.info("Generating output...")
        # Generate with the merged parameters, passing through only recognized generation kwargs
        outputs = model.generate(
            **inputs,
            **{k: v for k, v in generation_params.items() if k in [
                'max_length', 'min_length', 'do_sample', 'temperature',
                'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
                'repetition_penalty', 'early_stopping', 'length_penalty'
            ]}
        )
        logger.info("Output generated successfully")

        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Final result: {result}")
        return PredictionResponse(generated_text=result)
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        logger.error(f"Error type: {type(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))
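
# Example client call (illustrative sketch, not part of the original file): assumes the
# app is served over HTTP, e.g. locally with `uvicorn app:app --port 7860` (the module
# name "app" and the port are assumptions).
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/predict",
#       json={
#           "inputs": "I was flying over a city made of glass",
#           "model": "nidra-v2",
#           "parameters": {"max_length": 200},  # optional per-request overrides
#       },
#   )
#   resp.raise_for_status()
#   print(resp.json()["generated_text"])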