"""Gradio demo: ImageNet classification with DenseNet-121 (ONNX) run through onnxruntime."""
import os
from collections import namedtuple

import gradio as gr
import imageio
import matplotlib.pyplot as plt
import mxnet as mx
import numpy as np
import onnxruntime as ort
from mxnet.gluon.data.vision import transforms
from PIL import Image
# NOTE: this import intentionally comes after the mxnet one above and shadows it;
# `preprocess` below is built from torchvision's transforms.
from torchvision import transforms

# Resize/crop to 224x224 and normalize with the mean/std constants below.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
with open('synset.txt', 'r', encoding='utf-8') as f:
    # One class label per line; list index == model output index.
    labels = [l.rstrip() for l in f]

# mx.test_utils.download skips files that already exist and raises on HTTP errors,
# unlike the previous shell `wget`, which silently failed or created duplicates on rerun.
model_path = mx.test_utils.download(
    "https://github.com/AK391/models/raw/main/vision/classification/densenet-121/model/densenet-9.onnx"
)
ort_session = ort.InferenceSession(model_path)


def _softmax(scores):
    """Return the numerically stable softmax of a 1-D score vector."""
    shifted = np.exp(scores - np.max(scores))
    return shifted / shifted.sum()


def predict(pil, top_k=5):
    """Classify an image and return the ``top_k`` most probable labels.

    Args:
        pil: Input image (a PIL image, as supplied by the Gradio image input).
        top_k: Number of highest-probability classes to return (default 5).

    Returns:
        dict mapping label string -> probability (float), highest first.
    """
    # Normalize uses 3-channel statistics, so grayscale/RGBA/palette images
    # must be converted to RGB before preprocessing.
    input_tensor = preprocess(pil.convert("RGB"))
    img_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
    img_batch_np = img_batch.cpu().detach().numpy()
    outputs = ort_session.run(
        None,
        {"data_0": img_batch_np.astype(np.float32)},
    )
    # Flatten so the scores form a 1-D vector regardless of trailing singleton
    # dims in the model output, then convert scores to probabilities (the ONNX
    # model zoo lists softmax as this model's post-processing step).
    probs = _softmax(np.asarray(outputs[0], dtype=np.float64).flatten())
    # argsort is ascending; reverse it so the highest probabilities come first.
    top = np.argsort(probs)[::-1][:top_k]
    return {labels[i]: float(probs[i]) for i in top}


title = "DenseNet-121"
description = "DenseNet-121 is a convolutional neural network for classification."
# NOTE(review): presumably apple.jpg ships alongside this script — confirm.
examples = [['apple.jpg']]

# Legacy Gradio API (gr.inputs.*, enable_queue) kept to match the Gradio version this was written for.
gr.Interface(
    predict,
    gr.inputs.Image(type='pil'),
    "label",
    title=title,
    description=description,
    examples=examples,
).launch(enable_queue=True, debug=True)