import gradio as gr
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI

# Load the exported ONNX model once at startup, using the CPU execution provider.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])


def predict(input_ids: list[int], attention_mask: list[int]):
    # Batch the token IDs and attention mask into (1, seq_len) int64 arrays,
    # the layout the exported graph expects.
    input_ids_np = np.array([input_ids], dtype=np.int64)
    attention_mask_np = np.array([attention_mask], dtype=np.int64)

    # Run inference; passing None as the output names returns every graph output.
    outputs = session.run(None, {
        "input_ids": input_ids_np,
        "attention_mask": attention_mask_np
    })

    # numpy arrays are not JSON-serializable, so convert each output to a nested list.
    return [output.tolist() for output in outputs]

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.JSON(label="input_ids"),
        gr.JSON(label="attention_mask")
    ],
    outputs="json",
    allow_flagging="never"
)

# Mount the Gradio UI at the root of a FastAPI app so it can be served by an
# ASGI server; mount_gradio_app needs a real FastAPI instance, not None.
app = FastAPI()
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")
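# To serve the mounted app (the "app:" module name below assumes this file is
# saved as app.py; adjust it to match your filename):
#   uvicorn app:app --host 0.0.0.0 --port 7860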