File size: 2,403 Bytes
0a216f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
import torch
import numpy as np
from PIL import Image
import albumentations
import pandas as pd
from lightning_model import LitClassification

# Human-readable class names; predict() indexes this list with the model's
# output positions, so row order in the CSV must match the training labels.
df = pd.read_csv("imagenet_class_labels.csv")
class_labels = df["Labels"].tolist()

# Build a fresh LitClassification module and restore its trained weights.
# The checkpoint tensors are mapped onto the CPU so the app runs without a GPU.
# NOTE(review): torch.load unpickles the file — only ever load trusted checkpoints.
model = LitClassification()
checkpoint = torch.load(
    "bestmodel-epoch=46-monitor-val_acc1=63.7760009765625.ckpt",
    map_location="cpu",
)
model.load_state_dict(checkpoint["state_dict"])
model.eval()  # evaluation mode for inference

# Eval-time pipeline applied to every uploaded image:
#   1. resize to a fixed 224x224,
#   2. normalize per channel as (pixel - mean * 255) / (std * 255).
valid_aug = albumentations.Compose(
    transforms=[
        albumentations.Resize(height=224, width=224, p=1),
        albumentations.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
    ],
    p=1.0,
)

def preprocess_image(image):
    """Turn a PIL image into a normalized, batched model input.

    The image is coerced to RGB, a 4% border is trimmed from every edge,
    and the remainder is passed through ``valid_aug`` (resize + normalize).

    Args:
        image: PIL image, as delivered by the ``gr.Image(type="pil")`` input.

    Returns:
        A float32 tensor laid out channels-first with a leading batch
        dimension of 1, i.e. ``(1, 3, H', W')`` where ``H'``/``W'`` are the
        output size of ``valid_aug``.
    """
    # Grayscale, palette and RGBA uploads are all converted to 3 channels.
    if image.mode != "RGB":
        image = image.convert("RGB")

    pixels = np.array(image)  # (H, W, 3) uint8 after the RGB conversion

    # Trim 4% from each edge: keeps the central 92% of each side (~85% of
    # the area). The max(...) guarantees at least one row/column survives,
    # otherwise a 1-pixel-high/wide image slices to an empty array and the
    # resize step fails.
    height, width = pixels.shape[:2]
    top, left = int(0.04 * height), int(0.04 * width)
    bottom = max(int(0.96 * height), top + 1)
    right = max(int(0.96 * width), left + 1)
    pixels = pixels[top:bottom, left:right, :]

    # Resize + normalize (see valid_aug).
    pixels = valid_aug(image=pixels)["image"]

    # HWC -> CHW, then add the batch dimension expected by the model.
    return torch.tensor(pixels.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0)

def predict(image):
    """Classify ``image`` and return its most likely labels.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading an image.

    Returns:
        Dict mapping label text to probability for the (up to) 5 highest
        scoring classes, in descending probability order. Returns an empty
        dict when ``image`` is ``None``. If two class indices share the same
        label text, only the higher probability is kept.
    """
    # Gradio hands over None for an empty image input; nothing to classify.
    if image is None:
        return {}

    batch = preprocess_image(image)

    with torch.no_grad():
        scores = model(batch)
        # Softmax over the class dimension turns raw scores into probabilities.
        probabilities = torch.nn.functional.softmax(scores, dim=1)

    # Never request more entries than the model actually outputs.
    k = min(5, probabilities.shape[1])
    top_probs, top_indices = torch.topk(probabilities, k)

    results = {}
    for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist()):
        # topk yields descending order, so setdefault keeps the highest
        # probability when two indices map to the same label text (a plain
        # dict comprehension would let the lower one overwrite it).
        results.setdefault(class_labels[idx], prob)
    return results

# Web UI: one PIL image in, a label widget showing the 5 best classes out.
iface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=5),
    title="ImageNet Classification with ResNet50",
    description="Upload an image to classify it into one of 1000 ImageNet categories.",
    examples=["sample_imgs/stock-photo-large-hot-dog.jpg"],
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()