from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer
from onnxruntime import InferenceSession
import numpy as np
from typing import Dict

app = FastAPI(title="ONNX Model API with Tokenizer")

# CORS configuration: wide open for development; restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize components once at import time so every request reuses them.
tokenizer = AutoTokenizer.from_pretrained("Xenova/multi-qa-mpnet-base-dot-v1")
session = InferenceSession("model.onnx")


def convert_outputs(outputs):
    """Convert numpy values to native Python types so FastAPI can serialize them as JSON."""
    if isinstance(outputs, (np.generic, np.ndarray)):
        return outputs.item() if outputs.ndim == 0 else outputs.tolist()
    return outputs


@app.post("/api/process")
async def process_text(request: Dict[str, str]):
    try:
        text = request.get("text", "")

        # Tokenize the input text into numpy tensors.
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=32,  # Match your model's expected input size
        )

        # ONNX Runtime expects int64 tensors for this model's inputs.
        onnx_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        }

        # Run model inference.
        outputs = session.run(None, onnx_inputs)

        # Convert all numpy types to native Python types.
        processed_outputs = [convert_outputs(output) for output in outputs]

        return {
            # Assuming the first output is the embedding tensor; for this model it is
            # typically the token-level last_hidden_state, so pooling may be needed
            # to get a single sentence embedding (see the sketch below).
            "embedding": processed_outputs[0],
            # Convert numpy ids to a plain list before token lookup.
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist()),
        }
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy"}
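

# --- Optional: pooling sketch -----------------------------------------------
# The ONNX export of multi-qa-mpnet-base-dot-v1 usually emits token-level hidden
# states of shape (batch, seq_len, hidden). This is a minimal sketch, assuming
# that layout, of CLS pooling (the strategy this model family was trained with);
# it is not wired into the endpoint above.
def cls_pool(last_hidden_state: np.ndarray) -> np.ndarray:
    """Take the first ([CLS]) token's vector as the sentence embedding."""
    # (batch, seq_len, hidden) -> (batch, hidden)
    return last_hidden_state[:, 0, :]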
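

# --- Usage -------------------------------------------------------------------
# A minimal client sketch, assuming this file is saved as main.py and the server
# is started with `uvicorn main:app --reload`:
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8000/api/process",
#       json={"text": "What is ONNX Runtime?"},
#   )
#   print(resp.json()["tokens"])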