# Author: lakshya-raj
# Commit 583f98e - "v3.0.1-Test Image Classifier with Pytorch"
import gradio as gr
import numpy as np
import torch
import requests
from PIL import Image
from torchvision import transforms
# ResNet-18 with pretrained ImageNet weights, switched to inference mode.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
# Download human-readable labels for ImageNet: one label per line, presumably
# in the model's class-index order (predict pairs them positionally) - verify
# against the labels file if it is ever replaced.
# A timeout keeps app start-up from hanging forever on a stalled connection.
response = requests.get("https://git.io/JJkYN", timeout=30)
# Fail loudly rather than silently using an HTTP error page as the label list.
response.raise_for_status()
# splitlines() also strips "\r" if the file uses CRLF line endings.
labels = response.text.splitlines()
# def sepia(input_img):
# sepia_filter = np.array([
# [0.393, 0.769, 0.189],
# [0.349, 0.686, 0.168],
# [0.272, 0.534, 0.131]
# ])
# sepia_img = input_img.dot(sepia_filter.T)
# sepia_img /= sepia_img.max()
# return sepia_img
# def greet(name):
# return "Hello " + name + "!!"
# Preprocessing expected by torchvision's ImageNet-pretrained classifiers:
# convert to a float tensor scaled to [0, 1], then normalize each channel
# with the ImageNet mean/std. Built once at import time instead of per call.
_PREPROCESS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(inp):
    """Classify a PIL image with the module-level ResNet-18 ``model``.

    Args:
        inp: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        dict mapping each label in ``labels`` to its softmax probability,
        the mapping format consumed by ``gr.Label``.

    Raises:
        ValueError: if ``inp`` is ``None``.
    """
    if inp is None:
        raise ValueError("No image provided; please upload an image.")
    # Force 3 channels: grayscale ("L"), RGBA or palette uploads would
    # otherwise yield a 1- or 4-channel tensor that the model rejects.
    batch = _PREPROCESS(inp.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # zip stops at the shorter sequence, so a trailing blank line in the
    # labels file cannot raise IndexError; a short labels file would drop
    # the unmatched trailing classes instead of crashing.
    return {label: p for label, p in zip(labels, probabilities.tolist())}
# demo = gr.Interface(fn=sepia, inputs="image", outputs="image")
# Web UI: upload an image and show the three most probable classes
# returned by predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    # Example images referenced by relative path; presumably shipped
    # alongside this script - confirm they exist in the deployment.
    examples=["lion.jpg", "cheetah.jpg"],
)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # does not launch the app as a side effect.
    demo.launch()
# iface = gr.Interface(fn=greet, inputs="text", outputs="text")
# iface.launch()