File size: 900 Bytes
1e9ac73 a3e1970 1e9ac73 a3e1970 1e9ac73 a3e1970 1e9ac73 a3e1970 1e9ac73 a3e1970 1e9ac73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import gradio as gr
import numpy as np
import onnxruntime as ort
# Create the ONNX Runtime session once, at module import, so every call to
# `predict` reuses it; inference is restricted to the CPU execution provider.
session = ort.InferenceSession(
    "model.onnx",
    providers=["CPUExecutionProvider"],
)
def _as_int64_batch(values, name):
    """Convert one JSON input to a 2-D int64 array of shape (batch, seq_len).

    A flat list of ints is treated as a single sequence and gets a leading
    batch axis (the same result as ``np.array([values])``); a list of
    equal-length lists is used as-is as an already-batched input.

    Raises:
        ValueError: If ``values`` is None, cannot be converted to an int64
            array (ragged lists, dicts, non-numeric strings, ...), or is not
            1-D / 2-D.
    """
    if values is None:
        raise ValueError(f"{name} is required")
    try:
        arr = np.asarray(values, dtype=np.int64)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"{name} must be a list of integers or a rectangular list of lists of integers"
        ) from err
    if arr.ndim not in (1, 2):
        raise ValueError(
            f"{name} must be a 1-D or 2-D list of integers, got {arr.ndim}-D input"
        )
    return np.atleast_2d(arr)


# Prediction function
def predict(input_ids: list[int], attention_mask: list[int]):
    """Run the ONNX model on one tokenized sequence (or a batch of them).

    Args:
        input_ids: Token ids for a single sequence. A list of equal-length
            sequences is also accepted and treated as a batch.
        attention_mask: Mask values with the same shape as ``input_ids``.

    Returns:
        A list with every model output in ``session.run`` order. NumPy arrays
        are converted to nested Python lists so the result is plain
        JSON-compatible data for the JSON output component.

    Raises:
        ValueError: If an input is missing or malformed, or if the two inputs
            do not have the same shape.
    """
    input_ids_np = _as_int64_batch(input_ids, "input_ids")
    attention_mask_np = _as_int64_batch(attention_mask, "attention_mask")
    # Reject mismatched inputs here with a clear message instead of handing
    # inconsistent tensors to the model.
    if input_ids_np.shape != attention_mask_np.shape:
        raise ValueError(
            "input_ids and attention_mask must have the same shape, got "
            f"{input_ids_np.shape} and {attention_mask_np.shape}"
        )
    # output_names=None asks ONNX Runtime for all model outputs.
    outputs = session.run(None, {
        "input_ids": input_ids_np,
        "attention_mask": attention_mask_np,
    })
    return [out.tolist() if isinstance(out, np.ndarray) else out for out in outputs]
# Build the Gradio interface for the API endpoint: two JSON inputs, labelled
# "input_ids" and "attention_mask", are passed positionally to `predict`, whose
# return value is shown in a single JSON output. Flagging is disabled.
demo = gr.Interface(
    predict,
    [gr.JSON(label=field) for field in ("input_ids", "attention_mask")],
    "json",
    allow_flagging="never",
)
# Mount the Gradio interface at the root path and keep the result as `app`.
# NOTE(review): gr.mount_gradio_app is documented as mounting `blocks` onto an
# existing FastAPI application; with app=None this looks likely to fail when
# the module is imported. Confirm, and pass a real app instead, e.g.
# `from fastapi import FastAPI` then
# `app = gr.mount_gradio_app(FastAPI(), demo, path="/")`.
app = gr.mount_gradio_app(app=None, blocks=demo, path="/")
|