import os import sys current = os.path.dirname(os.path.realpath(__file__)) parent = os.path.dirname(current) sys.path.append(parent) import albumentations as A import gradio as gr import matplotlib.pyplot as plt import numpy as np import torch from albumentations.pytorch import ToTensorV2 from PIL import Image from model import Classifier # Load the model model = Classifier.load_from_checkpoint("./models/checkpoint.ckpt") model.eval() # Define labels labels = [ "dog", "horse", "elephant", "butterfly", "chicken", "cat", "cow", "sheep", "spider", "squirrel", ] # Preprocess function def preprocess(image): image = np.array(image) resize = A.Resize(224, 224) normalize = A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) to_tensor = ToTensorV2() transform = A.Compose([resize, normalize, to_tensor]) image = transform(image=image)["image"] return image # Define the sample images sample_images = { "dog": "./test_images/dog.jpeg", "cat": "./test_images/cat.jpeg", "butterfly": "./test_images/butterfly.jpeg", "elephant": "./test_images/elephant.jpg", "horse": "./test_images/horse.jpeg", } # Define the function to make predictions on an image def predict(image): try: image = preprocess(image).unsqueeze(0) # Prediction # Make a prediction on the image with torch.no_grad(): output = model(image) # convert to probabilities probabilities = torch.nn.functional.softmax(torch.exp(output[0]), dim=0) topk_prob, topk_label = torch.topk(probabilities, 3) # Return the top 3 predictions return {labels[i]: float(probabilities[i]) for i in range(3)} except Exception as e: print(f"Error predicting image: {e}") return [] # Define the interface def app(): title = "Animal-10 Image Classification" description = "Classify images using a custom CNN model and deploy using Gradio." gr.Interface( title=title, description=description, fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Label( num_top_classes=3, ), examples=[ "./test_images/dog.jpeg", "./test_images/cat.jpeg", "./test_images/butterfly.jpeg", "./test_images/elephant.jpg", "./test_images/horse.jpeg", ], ).launch() # Run the app if __name__ == "__main__": app()