Muthukamalan's picture
Upload folder using huggingface_hub
0a04b24 verified
raw
history blame
2.54 kB
import os
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from timeit import default_timer as timer
from torch.nn import functional as F
from gradio.flagging import SimpleCSVLogger
# Let PyTorch trade some float32 matmul precision for speed where the backend
# supports it.
torch.set_float32_matmul_precision("medium")

# Inference is pinned to the CPU. Automatic CUDA selection
# (torch.cuda.is_available()) and float16 autocast on CUDA are both
# intentionally left switched off.
device = torch.device("cpu")

# Tensors created without an explicit device land on `device`.
torch.set_default_device(device)
# Preprocessing applied to every input image before it reaches the model.
TEST_TRANSFORMS = transforms.Compose(
    [
        # Force a fixed 224x224 spatial size.
        transforms.Resize(size=(224, 224)),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel standardisation; these values match the commonly used
        # ImageNet statistics (three channels, so inputs must be RGB).
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Breed names, positioned by class index: predict_fn looks up the argmax of the
# model output in this list, so the order must not change.
class_labels = [
    "Beagle", "Boxer", "Bulldog", "Dachshund", "German_Shepherd",
    "Golden_Retriever", "Labrador_Retriever", "Poodle", "Rottweiler",
    "Yorkshire_Terrier",
]
# Model
# TorchScript archive loaded from the working directory; `map_location` already
# places every weight on `device`, so no further `.to(device)` is required.
model: torch.nn.Module = torch.jit.load("best_model.pt", map_location=device)
# Switch to evaluation mode explicitly so layers such as dropout and batch
# normalisation behave deterministically at inference time, regardless of the
# mode the module was in when it was saved.
model.eval()
@torch.no_grad()
def predict_fn(img: Image.Image):
    """Classify one image and report how long the prediction took.

    Args:
        img: PIL image handed over by the Gradio image input.

    Returns:
        A ``(scores, seconds)`` tuple. ``scores`` is a single-entry dict that
        maps ``"Title: <label>"`` to the softmax probability of the top class;
        ``seconds`` is the elapsed wall-clock time rounded to 5 decimals.
        When anything goes wrong the placeholder ``({"Title ☠️": 0.0}, 0.0)``
        is returned instead and a warning is shown in the UI.
    """
    start_time = timer()
    try:
        # TEST_TRANSFORMS normalises with three-channel statistics, so make
        # sure grayscale / RGBA / palette images are RGB before transforming.
        batch = TEST_TRANSFORMS(img.convert("RGB")).unsqueeze(0).to(device)
        logits = model(batch)
        probabilities = F.softmax(logits, dim=-1)
        y_pred = probabilities.argmax(dim=-1).item()
        confidence = probabilities[0][y_pred].item()
        predicted_label = class_labels[y_pred]
        pred_time = round(timer() - start_time, 5)
        return ({f"Title: {predicted_label}": confidence}, pred_time)
    except Exception as e:
        # Top-level UI boundary: log the cause, notify the user, and fall back
        # to a placeholder so the interface (and example caching) keeps going.
        print(f"error:: {e}")
        # gr.Error is an exception and only reaches the UI when raised; merely
        # constructing it displays nothing. gr.Warning is a plain call, so the
        # message is surfaced while the placeholder result is still returned.
        gr.Warning("An error occured 💥!", duration=5)
        return ({"Title ☠️": 0.0}, 0.0)
# Example images live in an "examples" folder next to this script. Resolve the
# folder against the script location for BOTH the listing and the paths given
# to Gradio, so the app does not depend on the current working directory.
_EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
_example_inputs = [
    [os.path.join(_EXAMPLES_DIR, name)]
    # Sorted for a stable gallery order; hidden files and sub-directories are
    # skipped because they are not images and would break example caching.
    for name in sorted(os.listdir(_EXAMPLES_DIR))
    if not name.startswith(".") and os.path.isfile(os.path.join(_EXAMPLES_DIR, name))
]

# Web UI: one image in, top-1 label plus prediction time out.
demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1, label="Predictions"),  # predicted breed + confidence
        gr.Number(label="Prediction time (s)"),
    ],
    examples=_example_inputs,
    # Dog emoji: the original title showed a cat on a dog-breed classifier.
    title="Dog Breeds Classifier 🐶",
    description="CNN-based Architecture for Fast and Accurate DogsBreed Classifier",
    article="Created by muthukamalan.m ❤️",
    # Predictions for the examples are computed once and cached.
    cache_examples=True,
    flagging_options=[],
    flagging_callback=SimpleCSVLogger(),
)

# Listen on all interfaces at port 7860, without a public share link.
demo.launch(
    share=False,
    debug=False,
    server_name="0.0.0.0",
    server_port=7860,
    enable_monitoring=None,
)