Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
from PIL import Image | |
import albumentations | |
import pandas as pd | |
from lightning_model import LitClassification | |
# Human-readable class names. `predict` looks labels up by the index of the
# model's output, so row order in the CSV must match the model's output order.
df = pd.read_csv("imagenet_class_labels.csv")
class_labels = df["Labels"].tolist()

# Build the Lightning module, restore the trained weights (mapped onto the CPU
# so no GPU is needed), and switch the network to evaluation mode.
# NOTE: torch.load unpickles the checkpoint file; only load files you trust.
model = LitClassification()
checkpoint = torch.load(
    "bestmodel-epoch=46-monitor-val_acc1=63.7760009765625.ckpt",
    map_location=torch.device("cpu"),
)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
# Evaluation-time transform applied to every input image:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved), then
#   2. per-channel mean/std normalization with the ImageNet statistics.
# max_pixel_value=255.0 means the input is expected as 0-255 pixel values.
valid_aug = albumentations.Compose(
    transforms=[
        albumentations.Resize(height=224, width=224, p=1),
        albumentations.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
    ],
    p=1.0,
)
def preprocess_image(image, crop_margin=0.04):
    """Turn a PIL image into a normalized, batched tensor for the model.

    Steps: force RGB, trim ``crop_margin`` of the height and width from each
    border, run ``valid_aug`` (resize + normalization), and finally reorder
    the pixel array from height-width-channel to channel-height-width with
    a leading batch axis.

    Args:
        image: A ``PIL.Image.Image`` (as produced by ``gr.Image(type="pil")``).
        crop_margin: Fraction of each dimension removed from *each* side
            before resizing. The default 0.04 keeps the central 92% of the
            height and width (roughly 85% of the area).

    Returns:
        A ``torch.float`` tensor with a leading batch dimension of size 1.

    Raises:
        ValueError: If ``crop_margin`` is outside ``[0, 0.5)``.
    """
    if not 0.0 <= crop_margin < 0.5:
        raise ValueError(f"crop_margin must be in [0, 0.5), got {crop_margin!r}")

    # Grayscale, palette and RGBA inputs are all normalized to 3 channels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixels = np.array(image)

    height, width = pixels.shape[:2]
    top, bottom = int(crop_margin * height), int((1 - crop_margin) * height)
    left, right = int(crop_margin * width), int((1 - crop_margin) * width)
    # For very small images (e.g. 1 pixel tall or wide) the margins collapse
    # to an empty slice, which the resize step cannot handle. In that case
    # keep the full extent of that dimension instead.
    if bottom <= top:
        top, bottom = 0, height
    if right <= left:
        left, right = 0, width
    pixels = pixels[top:bottom, left:right, :]

    pixels = valid_aug(image=pixels)["image"]

    # HWC -> CHW, then add the batch dimension the model expects.
    return torch.tensor(pixels.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0)
def predict(image, top_k=5):
    """Classify *image* and return its most likely class labels.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio input, or ``None`` when
            the user submits without uploading an image.
        top_k: How many highest-probability classes to return. The default
            of 5 matches the interface's ``gr.Label(num_top_classes=5)``.
            It is clamped to the number of classes the model outputs.

    Returns:
        Dict mapping class label -> softmax probability (Python float),
        inserted in descending probability order. Returns an empty dict
        when *image* is ``None``.
    """
    # Gradio passes None if "Submit" is pressed with no image uploaded.
    if image is None:
        return {}

    processed_image = preprocess_image(image)

    with torch.no_grad():
        logits = model(processed_image)
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    # torch.topk raises if k exceeds the size of the class dimension.
    k = min(top_k, probabilities.shape[1])
    top_probs, top_indices = torch.topk(probabilities[0], k)

    results = {}
    # .tolist() yields plain Python ints/floats, so list indexing and the
    # returned values don't depend on tensor __index__/__float__ conversions.
    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
        # topk returns values in descending order, so setdefault keeps the
        # highest probability if the label list repeats a class name
        # (a plain assignment would overwrite it with a lower one).
        results.setdefault(class_labels[idx], prob)
    return results
# Web UI: accepts an uploaded image (delivered to `predict` as a PIL image)
# and shows up to five label/probability pairs returned by `predict`.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="ImageNet Classification with ResNet50",
    description="Upload an image to classify it into one of 1000 ImageNet categories.",
    examples=["sample_imgs/stock-photo-large-hot-dog.jpg"],
)

if __name__ == "__main__":
    # Start the server only when this file is executed directly, not on import.
    iface.launch()