File size: 2,540 Bytes
dce9fbe
9c1f464
dce9fbe
9c1f464
 
dce9fbe
9c1f464
08a94a9
9c1f464
08a94a9
 
 
dce9fbe
319ddab
9c1f464
 
 
dce9fbe
 
 
 
 
0a04b24
 
 
 
 
 
 
 
 
 
dce9fbe
9c1f464
 
 
08a94a9
ffad640
9c1f464
 
dce9fbe
9c1f464
dce9fbe
 
 
9c1f464
 
 
dce9fbe
9c1f464
 
 
 
 
dce9fbe
08a94a9
dce9fbe
9c1f464
08a94a9
 
 
9c1f464
 
 
dce9fbe
08a94a9
dce9fbe
08a94a9
 
dce9fbe
 
08a94a9
 
dce9fbe
08a94a9
 
 
dce9fbe
08a94a9
 
0a04b24
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from timeit import default_timer as timer
from torch.nn import functional as F
from gradio.flagging import SimpleCSVLogger

# Global torch runtime configuration for the app.
# Select the "medium" float32 matmul precision setting (see the torch docs for
# the speed/precision trade-off of each level).
torch.set_float32_matmul_precision("medium")
# Inference is pinned to the CPU. The commented-out line below is the
# CUDA-if-available alternative, left disabled.
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu") 
# Make `device` the default for tensors created without an explicit device.
torch.set_default_device(device=device)
# CUDA float16 autocast is likewise left disabled.
# torch.autocast(enabled=True, dtype="float16", device_type="cuda")


# Preprocessing applied to every input image before inference: resize to
# 224x224, convert to a tensor, then normalise each of the three channels.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
_PREPROCESS_STEPS = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
TEST_TRANSFORMS = transforms.Compose(_PREPROCESS_STEPS)
# Breed names for the classifier's outputs. predict_fn indexes this list with
# the argmax of the model's probabilities, so the order is significant.
class_labels = [
    "Beagle", "Boxer", "Bulldog", "Dachshund", "German_Shepherd",
    "Golden_Retriever", "Labrador_Retriever", "Poodle", "Rottweiler",
    "Yorkshire_Terrier",
]


# Model
# TorchScript checkpoint, loaded once at import time. The relative path is
# resolved against the process's working directory.
model: torch.nn.Module = torch.jit.load("best_model.pt", map_location=device).to(device)
# Switch to evaluation mode explicitly so layers such as dropout and batch
# norm use their inference behaviour regardless of the mode the checkpoint
# was saved in.
model.eval()


@torch.no_grad()
def predict_fn(img: Image.Image):
    """Classify a dog image and report how long the prediction took.

    Args:
        img: PIL image supplied by the Gradio image input.

    Returns:
        A ``(label_scores, seconds)`` tuple. ``label_scores`` maps
        ``"Title: <breed>"`` to the softmax probability of the top class, and
        ``seconds`` is the elapsed time rounded to 5 decimal places. If
        anything fails, the error is printed, a warning is shown in the UI,
        and ``({"Title ☠️": 0.0}, 0.0)`` is returned instead of raising.
    """
    start_time = timer()
    try:
        # Normalize is configured for three channels, so force RGB before
        # preprocessing; a missing image (None) fails here and is handled
        # by the fallback below.
        batch = TEST_TRANSFORMS(img.convert("RGB")).to(device).unsqueeze(0)
        logits = model(batch)
        probabilities = F.softmax(logits, dim=-1)
        y_pred = probabilities.argmax(dim=-1).item()
        confidence = probabilities[0][y_pred].item()
        predicted_label = class_labels[y_pred]
        pred_time = round(timer() - start_time, 5)
        return ({f"Title: {predicted_label}": confidence}, pred_time)
    except Exception as e:
        print(f"error:: {e}")
        # gr.Error is an exception class: constructing it without `raise`
        # shows nothing. gr.Warning displays a notification when called and
        # lets the fallback result below still be returned.
        gr.Warning("An error occured 💥!", duration=5)
        return ({"Title ☠️": 0.0}, 0.0)


# Build and launch the web UI.
# Example images live in an "examples" directory next to this file. The same
# directory is used both to list the files and to build their paths, so the
# examples resolve no matter which working directory the app is started from.
_EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "examples")

gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type="pil"),
    outputs=[
        # Matches predict_fn's return tuple: label scores, then elapsed time.
        gr.Label(num_top_classes=1, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=[
        [os.path.join(_EXAMPLES_DIR, name)]
        # os.listdir order is arbitrary; sort for a stable example order.
        for name in sorted(os.listdir(_EXAMPLES_DIR))
    ],
    title="Dog Breeds Classifier 🐈",
    description="CNN-based Architecture for Fast and Accurate DogsBreed Classifier",
    article="Created by muthukamalan.m ❤️",
    cache_examples=True,
    flagging_options=[],
    flagging_callback=SimpleCSVLogger()
).launch(share=False, debug=False,server_name="0.0.0.0",server_port=7860,enable_monitoring=None)