Spaces:
Sleeping
Sleeping
File size: 2,540 Bytes
dce9fbe 9c1f464 dce9fbe 9c1f464 dce9fbe 9c1f464 08a94a9 9c1f464 08a94a9 dce9fbe 319ddab 9c1f464 dce9fbe 0a04b24 dce9fbe 9c1f464 08a94a9 ffad640 9c1f464 dce9fbe 9c1f464 dce9fbe 9c1f464 dce9fbe 9c1f464 dce9fbe 08a94a9 dce9fbe 9c1f464 08a94a9 9c1f464 dce9fbe 08a94a9 dce9fbe 08a94a9 dce9fbe 08a94a9 dce9fbe 08a94a9 dce9fbe 08a94a9 0a04b24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import os
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from timeit import default_timer as timer
from torch.nn import functional as F
from gradio.flagging import SimpleCSVLogger
# --- Runtime configuration --------------------------------------------------
# Inference is pinned to the CPU. To use a GPU when present, replace the line
# below with: torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
# Tensors created without an explicit device will be allocated on `device`.
torch.set_default_device(device=device)
# "medium" permits float32 matrix multiplications to use faster,
# lower-precision internal math on backends that support it.
torch.set_float32_matmul_precision("medium")
# Per-channel (R, G, B) normalisation statistics -- the standard ImageNet values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image before inference: resize to exactly
# 224x224, convert the PIL image to a float tensor, then normalise each of the
# three channels with the statistics above.
TEST_TRANSFORMS = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Breed names indexed by the model's output position: predict_fn looks up
# class_labels[argmax(probabilities)], so this order must match the class
# order used when the model was trained (it is alphabetical, which is what
# e.g. torchvision's ImageFolder produces -- TODO confirm against training).
class_labels = """
    Beagle
    Boxer
    Bulldog
    Dachshund
    German_Shepherd
    Golden_Retriever
    Labrador_Retriever
    Poodle
    Rottweiler
    Yorkshire_Terrier
""".split()
# Model
# TorchScript archive stored next to this script. Resolving it from __file__
# (as the examples directory is) keeps loading working even when the app is
# launched from a different working directory.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "best_model.pt")
# map_location loads the weights directly onto `device`.
model: torch.nn.Module = torch.jit.load(MODEL_PATH, map_location=device)
# Serving is inference-only: force dropout / batch-norm layers into eval
# behaviour in case the archive was saved while still in training mode.
model.eval()
@torch.no_grad()
def predict_fn(img: Image.Image):
    """Predict the dog breed shown in ``img``.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input; may be
            ``None`` if no image was supplied.

    Returns:
        A ``(scores, pred_time)`` tuple. ``scores`` maps
        ``"Title: <breed>"`` to the softmax probability of the top-scoring
        class (rendered by ``gr.Label``); ``pred_time`` is the elapsed
        wall-clock time in seconds, rounded to 5 decimals. If anything
        fails, the error is printed, a warning toast is shown and the
        placeholder ``({"Title ☠️": 0.0}, 0.0)`` is returned instead.
    """
    start_time = timer()
    try:
        if img is None:
            raise ValueError("no image was provided")
        # TEST_TRANSFORMS normalises exactly three channels, so force RGB:
        # RGBA, greyscale ("L") or palette ("P") images would otherwise raise
        # a shape mismatch inside transforms.Normalize.
        batch = TEST_TRANSFORMS(img.convert("RGB")).unsqueeze(0).to(device)
        logits = model(batch)
        probabilities = F.softmax(logits, dim=-1)
        y_pred = probabilities.argmax(dim=-1).item()
        confidence = probabilities[0][y_pred].item()
        predicted_label = class_labels[y_pred]
        pred_time = round(timer() - start_time, 5)
        return ({f"Title: {predicted_label}": confidence}, pred_time)
    except Exception as e:
        # Deliberate catch-all at the UI boundary: report the problem and
        # return a placeholder result instead of failing the request.
        print(f"error:: {e}")
        # gr.Error is an exception class -- constructing it without `raise`
        # displays nothing. gr.Warning shows the toast and still lets the
        # fallback outputs below be returned.
        gr.Warning("An error occurred 💥!", duration=5)
        return ({"Title ☠️": 0.0}, 0.0)
# Example gallery: every image file in the ``examples`` directory next to this
# script, in a stable (sorted) order. Paths are built from the same directory
# that is listed, so they resolve regardless of the launch working directory;
# non-image entries (e.g. .DS_Store, README files) are skipped so that example
# caching does not fail on them.
_EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
_EXAMPLE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif", ".tif", ".tiff")


def _list_examples():
    """Return ``[[image_path], ...]`` rows for ``gr.Interface(examples=...)``."""
    return [
        [os.path.join(_EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(_EXAMPLES_DIR))
        if name.lower().endswith(_EXAMPLE_EXTENSIONS)
    ]


# predict_fn returns (label -> confidence dict, elapsed seconds); those two
# values feed the Label and Number outputs below, in that order.
demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=_list_examples(),
    title="Dog Breeds Classifier 🐕",
    description="CNN-based Architecture for Fast and Accurate Dog Breed Classification",
    article="Created by muthukamalan.m ❤️",
    # Gradio precomputes and stores predict_fn's outputs for the examples.
    cache_examples=True,
    flagging_options=[],
    flagging_callback=SimpleCSVLogger(),
)
demo.launch(
    share=False,
    debug=False,
    server_name="0.0.0.0",
    server_port=7860,
    enable_monitoring=None,
)
|