from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration
import logging
import os
import sys
import traceback

# Initialize FastAPI first
app = FastAPI()

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN")

MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2",
}


class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"


class PredictionResponse(BaseModel):
    generated_text: str


# Endpoint routes (paths assumed from the handler names)
@app.get("/version")
async def version():
    return {"python_version": sys.version}


@app.get("/health")
async def health():
    return {"status": "healthy"}


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    # Reject model names that are not in the registry before doing any work
    if request.model not in MODELS:
        raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")
    try:
        logger.info(f"Loading model: {request.model}")
        model_path = MODELS[request.model]
        # Add debug logging
        logger.info("Attempting to load tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained(
            model_path,
            token=HF_TOKEN,
            local_files_only=False,
        )
logger.info("Tokenizer loaded successfully") | |
logger.info("Attempting to load model...") | |
model = T5ForConditionalGeneration.from_pretrained( | |
model_path, | |
token=HF_TOKEN, | |
local_files_only=False | |
) | |
logger.info("Model loaded successfully") | |

        full_input = "Interpret this dream: " + request.inputs
        logger.info(f"Processing input: {full_input}")

        logger.info("Tokenizing input...")
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        logger.info("Input tokenized successfully")

        logger.info("Generating output...")
        outputs = model.generate(**inputs, max_length=200)
        logger.info("Output generated successfully")

        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Final result: {result}")
        return PredictionResponse(generated_text=result)
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        logger.error(f"Error type: {type(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=str(e))
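

# A minimal local-run sketch, not taken from the original Space: it assumes this
# file is named app.py and that the server listens on port 7860 (the usual
# Hugging Face Docker Space port). Adjust host/port for your own setup.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request once the server is up (the dream text is hypothetical):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "I was flying over a city", "model": "nidra-v1"}'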