"""Gradio web demo that classifies an uploaded image into one of the CIFAR-10 classes."""

from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

# Path to the trained checkpoint.
model_path = "cifar_net.pth"

# CIFAR-10 class labels; order must match the label indices the model was trained with.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


class Net(nn.Module):
    """Small CNN for 3x32x32 inputs with 10 output logits.

    NOTE(review): this assumes the checkpoint was produced by the PyTorch
    "Training a Classifier" tutorial, which saves ``Net().state_dict()`` to
    ``cifar_net.pth`` and uses the same (0.5, 0.5, 0.5) normalization as below.
    If the checkpoint came from a different architecture, replace this class.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def load_model(path=model_path):
    """Load the classifier from ``path`` onto the CPU and put it in eval mode.

    Accepts either a pickled ``nn.Module`` or a ``state_dict`` for :class:`Net`.
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    if isinstance(checkpoint, nn.Module):
        net = checkpoint
    else:
        net = Net()
        net.load_state_dict(checkpoint)
    net.eval()
    return net


model = load_model(model_path)

# Built once and reused for every request. Resize to the 32x32 resolution the
# network expects, then map pixel values from [0, 1] to [-1, 1].
_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


def classify_image(image):
    """Return the predicted CIFAR-10 class name for a single image.

    Args:
        image: a PIL image (or an array convertible via ``Image.fromarray``).

    Returns:
        One of the strings in ``classes``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Force 3 channels so RGBA / grayscale uploads match the model input.
    batch = _transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        outputs = model(batch)
    predicted = outputs.argmax(dim=1)
    return classes[predicted.item()]


def classify_images(images):
    """Classify each image in ``images``; returns a list of class names."""
    return [classify_image(image) for image in images]


inputs_image = gr.Image(label="Input Image", type="pil")
outputs_image = gr.Label(label="Predicted Class")

# Only offer example images that are actually present; missing files break the UI.
_example_files = ("image_0.jpg", "image_1.jpg")
_examples = [[name] for name in _example_files if Path(name).is_file()] or None

interface_image = gr.Interface(
    fn=classify_image,  # the input component delivers one image per call
    inputs=inputs_image,
    outputs=outputs_image,
    title="CIFAR-10 Image Classifier",
    description="Classify images into one of the CIFAR-10 classes.",
    examples=_examples,
    allow_flagging="never",
)

if __name__ == "__main__":
    interface_image.launch()