chryzxc's picture
Update app.py
1e9ac73 verified
raw
history blame
900 Bytes
import gradio as gr
import numpy as np
import onnxruntime as ort
# Create one shared inference session for "model.onnx" when the module is
# imported, restricted to the CPU execution provider.
session = ort.InferenceSession(
    "model.onnx",
    providers=["CPUExecutionProvider"],
)
# Prediction function
def predict(input_ids: list[int], attention_mask: list[int]):
    """Run the ONNX session on one pair of token-id / attention-mask lists.

    Each argument is wrapped in an outer list and converted to an int64
    array, so the model is fed a batch that contains a single sequence.

    Args:
        input_ids: Values fed to the model input named "input_ids".
        attention_mask: Values fed to the model input named "attention_mask".

    Returns:
        The result of ``session.run`` with no output names requested
        (``None`` as its first argument), passed through without any
        post-processing.
    """
    # Build the feed dictionary in one pass; the outer list adds the batch
    # dimension and the dtype is forced to int64 for both inputs.
    feeds = {
        name: np.array([values], dtype=np.int64)
        for name, values in (
            ("input_ids", input_ids),
            ("attention_mask", attention_mask),
        )
    }
    return session.run(None, feeds)
# Expose `predict` through a Gradio interface: two JSON inputs (labelled
# "input_ids" and "attention_mask"), one JSON output, flagging disabled.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.JSON(label="input_ids"),
        gr.JSON(label="attention_mask"),
    ],
    outputs="json",
    allow_flagging="never",
)

# The previous code called gr.mount_gradio_app(app=None, ...). That helper
# attaches the interface to an *existing* FastAPI application, so passing
# None fails as soon as the module is executed. This file does not create a
# FastAPI application, so serve the interface with Gradio's own server when
# the file is run as a script.
if __name__ == "__main__":
    demo.launch()