|
import gradio as gr |
|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
|
|
|
|
# Inference script: classify a single image with a CIFAR-10 network stored at
# ``model_path`` and print the predicted label.

model_path = "cifar_net.pth"
image_path = 'download.jpg'

# Preprocessing applied before inference. NOTE(review): presumably mirrors the
# training pipeline — TODO confirm. Resizes to 32x32, converts to a float
# tensor in [0, 1], then maps each of the three channels to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Label names; entry i is reported when output column i scores highest.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']


def load_model(path=model_path):
    """Load the pickled model stored at *path* and switch it to eval mode.

    Args:
        path: Checkpoint file expected to contain a whole model object,
            i.e. written with ``torch.save(model, path)``.

    Returns:
        The loaded model, on the CPU and in evaluation mode.

    Raises:
        TypeError: If the file holds a bare ``state_dict`` rather than a model.
    """
    # map_location="cpu": the input batch built by preprocess() lives on the
    # CPU, so a checkpoint saved from a GPU must be moved to the CPU as well —
    # otherwise loading fails on CPU-only hosts, or inference hits a device
    # mismatch on GPU hosts.
    # weights_only=False: since PyTorch 2.6 torch.load defaults to
    # weights_only=True, which refuses to unpickle a full model object.
    # Unpickling can execute arbitrary code, so only load trusted checkpoints.
    net = torch.load(path, map_location="cpu", weights_only=False)
    if isinstance(net, dict):
        raise TypeError(
            f"{path!r} contains a state_dict, not a model; construct the "
            "network class and call load_state_dict() on it instead."
        )
    net.eval()
    return net


def preprocess(image):
    """Turn a PIL image into a normalized batch of one, shape (1, 3, 32, 32)."""
    # convert("RGB") guarantees exactly three channels. Grayscale ("L"),
    # palette ("P"), RGBA or CMYK inputs would otherwise yield a channel count
    # that the 3-channel Normalize above cannot handle.
    return transform(image.convert("RGB")).unsqueeze(0)


def predict(model, image):
    """Return the label from ``classes`` that *model* assigns to *image*."""
    batch = preprocess(image)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        scores = model(batch)
    # Index of the highest score along the class dimension of the single row.
    class_index = scores.argmax(dim=1).item()
    return classes[class_index]


def main():
    """Classify ``image_path`` with the model at ``model_path`` and print it."""
    model = load_model(model_path)
    # Image.open reads lazily and keeps the file open; the with-block closes
    # it once predict() has consumed the pixel data.
    with Image.open(image_path) as image:
        label = predict(model, image)
    print('Predicted class label:', label)


if __name__ == "__main__":
    main()