"""Serve an ONNX model through a Gradio JSON interface mounted on FastAPI."""

import gradio as gr
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI  # installed as a Gradio dependency

# Load the ONNX model once at import time so every request reuses the session.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])


def predict(input_ids: list[int], attention_mask: list[int]) -> list:
    """Run the ONNX model on a single example and return JSON-friendly outputs.

    Args:
        input_ids: Integers fed to the model's ``input_ids`` input.
        attention_mask: Integers fed to the model's ``attention_mask`` input;
            must have the same length as ``input_ids``.

    Returns:
        One entry per model output, in the order ``session.run`` returns them.
        Array outputs are converted to nested Python lists.

    Raises:
        ValueError: If either input is missing or their shapes differ.
    """
    if input_ids is None or attention_mask is None:
        raise ValueError("Both input_ids and attention_mask are required.")

    # Wrap each sequence in an outer list to add a leading batch axis of 1.
    input_ids_np = np.array([input_ids], dtype=np.int64)
    attention_mask_np = np.array([attention_mask], dtype=np.int64)

    # Fail early with a readable message instead of an opaque runtime error.
    if input_ids_np.shape != attention_mask_np.shape:
        raise ValueError(
            "input_ids and attention_mask must have the same shape, got "
            f"{input_ids_np.shape} and {attention_mask_np.shape}."
        )

    # Passing None as the first argument requests every model output.
    outputs = session.run(
        None,
        {
            "input_ids": input_ids_np,
            "attention_mask": attention_mask_np,
        },
    )

    # Convert arrays to plain lists so the JSON output does not depend on the
    # serializer understanding numpy types; other values pass through as-is.
    return [
        out.tolist() if isinstance(out, np.ndarray) else out
        for out in outputs
    ]


# Expose API endpoint
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.JSON(label="input_ids"),
        gr.JSON(label="attention_mask"),
    ],
    outputs="json",
    allow_flagging="never",
)

# mount_gradio_app attaches the interface to an existing FastAPI application,
# so a real instance is required here; passing None leaves nothing to mount on.
app = gr.mount_gradio_app(app=FastAPI(), blocks=demo, path="/")