# Trial / app.py
# Author: pavi156 — commit c5e43e4 ("Update app.py"), 985 bytes.
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
# CIFAR-10 class names, indexed by the model's output class position.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

MODEL_PATH = "cifar_net.pth"
IMAGE_PATH = 'download.jpg'


def load_model(path=MODEL_PATH):
    """Load a fully-serialized model from ``path`` and switch it to eval mode.

    NOTE(review): this assumes the checkpoint holds a whole ``nn.Module``
    (saved with ``torch.save(model, path)``), since ``.eval()`` is called on
    the loaded object. If it was saved as a ``state_dict`` instead, the
    network class must be instantiated and ``load_state_dict`` used.

    Returns:
        The loaded model, in eval mode, on the CPU.
    """
    # map_location="cpu": lets a checkpoint that was saved on a GPU load on a
    # CPU-only host. weights_only=False: required to unpickle a full module
    # on PyTorch >= 2.6, where the default became True. Unpickling can
    # execute arbitrary code, so only load checkpoints from a trusted source.
    model = torch.load(path, map_location="cpu", weights_only=False)
    model.eval()
    return model


def build_transform():
    """Return the pipeline that turns a PIL image into a 32x32 float tensor.

    Steps: resize to 32x32, convert to a tensor with values in [0, 1], then
    normalize each channel with mean 0.5 / std 0.5, which rescales the
    values to [-1, 1].
    """
    # (0.5, 0.5, 0.5) are not the CIFAR-10 dataset statistics; what matters
    # is that they match the normalization the model was trained with.
    return transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def load_image(path=IMAGE_PATH, transform=None):
    """Open the image at ``path`` and return it as a (1, 3, 32, 32) batch.

    Args:
        path: Image file to read.
        transform: Preprocessing pipeline; defaults to ``build_transform()``.
    """
    if transform is None:
        transform = build_transform()
    # convert("RGB") guarantees 3 channels: grayscale, palette, RGBA or CMYK
    # inputs would otherwise break the 3-channel Normalize. It also forces
    # the pixel data to load, so the file can be closed right away.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)  # add the batch dimension


def predict(model, input_tensor, labels=None):
    """Run ``model`` on a single-sample batch and return the predicted label.

    Args:
        model: Callable returning per-class scores, with classes along dim 1.
        input_tensor: Batch holding exactly one sample (``.item()`` below
            rejects larger batches).
        labels: Sequence of class names; defaults to ``classes``.

    Returns:
        ``labels[i]`` for the highest-scoring class index ``i``.
    """
    if labels is None:
        labels = classes
    with torch.no_grad():
        outputs = model(input_tensor)
    _, predicted = torch.max(outputs, 1)
    return labels[predicted.item()]


if __name__ == "__main__":
    # Same pipeline and output as before; it now runs only as a script,
    # not on import.
    model = load_model(MODEL_PATH)
    transform = build_transform()
    input_image = load_image(IMAGE_PATH, transform)
    print('Predicted class label:', predict(model, input_image))