File size: 1,907 Bytes
a66ef08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d89502
a66ef08
 
 
 
4d89502
 
 
a66ef08
4d89502
a66ef08
 
4d89502
a66ef08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from resnet_model import ResNet50
from utils import load_checkpoint
import ast

# Build the network and restore its trained weights.
# Inference runs on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The model is wrapped in DataParallel before the checkpoint is loaded.
# NOTE(review): presumably the checkpoint was saved from a DataParallel
# model (state-dict keys prefixed with "module.") -- confirm in utils.
model = torch.nn.DataParallel(ResNet50()).to(device)

# load_checkpoint returns four values; only the first (the model) is kept.
# NOTE(review): the None argument is presumably the optimizer slot, which
# is not needed for inference -- confirm against utils.load_checkpoint.
checkpoint_path = "checkpoint.pth"
model, _, _, _ = load_checkpoint(model, None, checkpoint_path)

# Put the module in evaluation mode for inference.
model.eval()

# Preprocessing applied to every input image before it reaches the model:
#   1. resize to a fixed 224x224 (the aspect ratio is not preserved),
#   2. convert the PIL image to a tensor,
#   3. normalize each of the three channels with the given mean/std
#      (the values commonly used for ImageNet-trained models).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Load class labels from the file.
# The file content is parsed with ast.literal_eval, which only accepts
# Python literals (it does not execute code, unlike eval).
# NOTE(review): the result is assumed to be a dict mapping integer class
# index -> label string, since predict() looks labels up with .get(index).
# The encoding is given explicitly so decoding does not depend on the
# platform's locale default.
with open("imagenet1000_clsidx_to_labels.txt", encoding="utf-8") as f:
    class_labels = ast.literal_eval(f.read())

# Define the prediction function
def predict(image):
    """Classify an image and return the top predictions as an HTML snippet.

    Args:
        image: PIL image to classify (the Gradio input is declared with
            ``type="pil"``), or ``None`` when no image was supplied.

    Returns:
        An HTML string. For a valid image it lists up to five classes, each
        with its label, probability in percent, and a proportional bar. For
        ``None`` it contains a short prompt asking for an image.
    """
    container_open = "<div style='font-family: Arial, sans-serif; font-size: 18px;'>"

    # Nothing to classify (e.g. submit pressed without an image): return a
    # readable message instead of failing inside the transform.
    if image is None:
        return f"{container_open}Please upload an image to classify.</div>"

    # Normalize in the preprocessing pipeline is configured for exactly three
    # channels, so grayscale / palette / alpha images are converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension and move the tensor to the model's device.
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(batch)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

    # Never ask topk for more entries than there are classes.
    top_k = min(5, probabilities.size(0))
    top_prob, top_catid = torch.topk(probabilities, top_k)

    parts = [container_open]
    # tolist() moves the values to Python scalars once, instead of calling
    # .item() for every element.
    for class_index, fraction in zip(top_catid.tolist(), top_prob.tolist()):
        class_label = class_labels.get(class_index, "Unknown")
        prob = fraction * 100
        parts.append(
            f"<div style='margin-bottom: 10px;'><strong>{class_label}</strong>: {prob:.2f}%</div>"
        )
        parts.append(
            f"<div style='background-color: #ddd; width: 100%;'><div style='width: {prob}%; background-color: #4CAF50; height: 20px;'></div></div>"
        )
    parts.append("</div>")

    return "".join(parts)

# Build the web UI: one PIL-image input wired to predict(), whose returned
# string is rendered as HTML, then start serving it.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="html",
    title="ResNet 50 Image Classifier",
)
iface.launch()