nidra / app.py
m1k3wn's picture
Update app.py
2580a1e verified
raw
history blame
2.61 kB
import functools
import logging
import os
import sys

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Create the FastAPI application first: the route decorators further down
# (@app.get / @app.post) need it to exist when they run.
app = FastAPI()

# Log level defaults to DEBUG (the previously hard-coded level) but can be
# overridden via the LOG_LEVEL environment variable, e.g. LOG_LEVEL=INFO.
# logging accepts level names as strings; an unknown name raises ValueError
# at startup rather than being silently ignored.
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "DEBUG").upper())
logger = logging.getLogger(__name__)

# Hugging Face access token forwarded to from_pretrained; None when the
# HF_TOKEN environment variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Public model name (the value of PredictionRequest.model) -> repository id
# passed to from_pretrained.
MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2",
}
class PredictionRequest(BaseModel):
    # Request body for POST /predict.
    # (Comments rather than a docstring: pydantic would copy a docstring
    # into the published OpenAPI schema.)

    # Dream description; the endpoint prefixes it with "Interpret this dream: ".
    inputs: str
    # Key into MODELS choosing which model answers; "nidra-v1" when omitted.
    model: str = "nidra-v1"
class PredictionResponse(BaseModel):
    # Response body for POST /predict.

    # Decoded model output with special tokens stripped.
    generated_text: str
@app.get("/version")
async def version():
    # Expose the interpreter's full version string (sys.version) so a
    # deployment's Python runtime can be checked remotely.
    runtime_info = {"python_version": sys.version}
    return runtime_info
@app.get("/health")
async def health():
    # Static liveness check: always reports "healthy" and does not touch
    # any model state.
    state = "healthy"
    return {"status": state}
@functools.lru_cache(maxsize=None)
def _load_model(model_name):
    """Load the tokenizer and model registered under *model_name* in MODELS.

    Results are memoized per model name, so downloading and deserializing
    the weights happens once per process instead of on every request.
    lru_cache does not store exceptions, so a failed load is retried on the
    next call and only keys of MODELS can ever end up cached. Concurrent
    first requests for the same model may each perform the load once.

    Returns:
        A ``(tokenizer, model)`` pair.

    Raises:
        KeyError: *model_name* is not a key of MODELS.
        Any error raised by ``from_pretrained`` is propagated unchanged.
    """
    model_path = MODELS[model_name]
    logger.info("Loading tokenizer for %s from %s", model_name, model_path)
    tokenizer = T5Tokenizer.from_pretrained(
        model_path,
        token=HF_TOKEN,
        local_files_only=False,
    )
    logger.info("Loading model for %s from %s", model_name, model_path)
    model = T5ForConditionalGeneration.from_pretrained(
        model_path,
        token=HF_TOKEN,
        local_files_only=False,
    )
    logger.info("Model %s loaded successfully", model_name)
    return tokenizer, model


def _run_prediction(model_name, dream_text):
    """Interpret *dream_text* with the (cached) model for *model_name*.

    Blocking: tokenization and generation run synchronously, so callers on
    the event loop should dispatch this to a worker thread.
    """
    tokenizer, model = _load_model(model_name)
    full_input = "Interpret this dream: " + dream_text
    logger.info("Processing input: %s", full_input)
    inputs = tokenizer(
        full_input,
        return_tensors="pt",
        truncation=True,
        max_length=512,  # prompts longer than 512 tokens are truncated
        padding=True,
    )
    outputs = model.generate(**inputs, max_length=200)
    # A single prompt was encoded, so the first sequence is the answer.
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info("Final result: %s", result)
    return result


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Interpret a dream with the selected model.

    Responds 400 when the requested model name is unknown, and 500 when
    loading the model or generating the interpretation fails.
    """
    # Validate before the try block so a bad model name is reported as a
    # client error instead of being swallowed into a generic 500.
    if request.model not in MODELS:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Unknown model {request.model!r}. "
                f"Available models: {', '.join(MODELS)}"
            ),
        )
    try:
        # Loading and generation are blocking calls; running them in the
        # threadpool keeps the event loop free for other requests.
        result = await run_in_threadpool(
            _run_prediction, request.model, request.inputs
        )
    except Exception as e:
        # logger.exception records the message, type and traceback at once.
        logger.exception("Prediction failed for model %s", request.model)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return PredictionResponse(generated_text=result)