nishantb06 commited on
Commit
0a216f0
·
verified ·
1 Parent(s): 0a751cc
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import albumentations
6
+ import pandas as pd
7
+ from lightning_model import LitClassification
8
+
9
+ # Load class labels
10
+ df = pd.read_csv("imagenet_class_labels.csv")
11
+ class_labels = df['Labels'].tolist()
12
+
13
+ # Initialize model and load checkpoint
14
+ model = LitClassification()
15
+ checkpoint = torch.load("bestmodel-epoch=46-monitor-val_acc1=63.7760009765625.ckpt",
16
+ map_location=torch.device('cpu')) # Load to CPU by default
17
+ model.load_state_dict(checkpoint['state_dict'])
18
+ model.eval()
19
+
20
+ # Image preprocessing
21
+ valid_aug = albumentations.Compose(
22
+ [
23
+ albumentations.Resize(224, 224, p=1),
24
+ albumentations.Normalize(
25
+ mean=[0.485, 0.456, 0.406],
26
+ std=[0.229, 0.224, 0.225],
27
+ max_pixel_value=255.0,
28
+ p=1.0,
29
+ ),
30
+ ],
31
+ p=1.0,
32
+ )
33
+
34
+ def preprocess_image(image):
35
+ # Convert to RGB if needed
36
+ if image.mode != "RGB":
37
+ image = image.convert("RGB")
38
+
39
+ # Convert to numpy array
40
+ image = np.array(image)
41
+
42
+ # Center crop 95% area
43
+ H, W, C = image.shape
44
+ image = image[int(0.04 * H) : int(0.96 * H), int(0.04 * W) : int(0.96 * W), :]
45
+
46
+ # Apply augmentations
47
+ augmented = valid_aug(image=image)
48
+ image = augmented["image"]
49
+
50
+ # Convert to tensor and add batch dimension
51
+ image = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0)
52
+ return image
53
+
54
+ def predict(image):
55
+ # Preprocess the image
56
+ processed_image = preprocess_image(image)
57
+
58
+ # Get model prediction
59
+ with torch.no_grad():
60
+ outputs = model(processed_image)
61
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
62
+
63
+ # Get top 5 predictions
64
+ top5_prob, top5_indices = torch.topk(probabilities, 5)
65
+
66
+ # Convert predictions to labels and probabilities
67
+ results = {
68
+ class_labels[idx]: float(prob)
69
+ for prob, idx in zip(top5_prob[0], top5_indices[0])
70
+ }
71
+
72
+ return results
73
+
74
+ # Create Gradio interface
75
+ iface = gr.Interface(
76
+ fn=predict,
77
+ inputs=gr.Image(type="pil"),
78
+ outputs=gr.Label(num_top_classes=5),
79
+ examples=["sample_imgs/stock-photo-large-hot-dog.jpg"],
80
+ title="ImageNet Classification with ResNet50",
81
+ description="Upload an image to classify it into one of 1000 ImageNet categories."
82
+ )
83
+
84
+ if __name__ == "__main__":
85
+ iface.launch()