FoodVisionMini / app.py
Th3BossC's picture
first commit
788e4aa
raw
history blame
1.39 kB
import gradio as gr
import torch
import os
import torch.nn as nn
import torchvision
from model import create_model
from timeit import default_timer as timer
from typing import Tuple, Dict
#################
title = 'FoodVisionMini - Diljith'
description = 'Classifies an image of food item into either one of 3 classes : pizza, steak or sushi'
# Output labels; predict() maps output index i of the model to class_names[i].
class_names = ['pizza', 'steak', 'sushi']
# create_model is project-local; presumably it returns (model, preprocessing
# transforms, optimizer, loss function). Only the model and transforms are used
# in this visible file. The class count is derived from class_names so the
# two cannot drift apart.
effnetb2, transforms, optimizer, lossFunc = create_model(num_classes = len(class_names))
# map_location is an argument of torch.load, not of load_state_dict (which
# rejects it with TypeError). It remaps any tensors saved on a GPU onto the
# CPU so the app can start on CPU-only hosts.
effnetb2.load_state_dict(torch.load(f = 'effnet_b2-20%-10epochs.pth', map_location = torch.device('cpu')))
def predict(img):
    """Classify one image and report how long the forward pass took.

    Args:
        img: The value produced by the ``gr.Image(type='pil')`` input. It is
            passed directly to ``transforms``; presumably a PIL image.

    Returns:
        A tuple ``(probs_dict, elapsed)``:
        - ``probs_dict`` maps each name in ``class_names`` to its softmax
          probability as a Python float (the format ``gr.Label`` consumes).
        - ``elapsed`` is the time in seconds measured with
          ``timeit.default_timer`` around inference and result formatting.
    """
    # Preprocess and add a leading batch dimension: the model runs on a batch of one.
    batch = transforms(img).unsqueeze(0)
    effnetb2.eval()
    start = timer()
    with torch.inference_mode():
        pred_probs = effnetb2(batch).softmax(dim = 1)
    # Pair each label with the probability at the same output index.
    probs_dict = dict(zip(class_names, pred_probs[0].tolist()))
    # timer must be *called*; subtracting the function object raises TypeError.
    end = timer()
    return probs_dict, end - start
# Directory holding the clickable example images shown under the interface.
examples_path = 'examples'
# One inner list per example, holding the value for the single image input.
# Sorted so the examples appear in a stable order, and left empty if the
# directory is missing rather than crashing at startup.
example_list = (
    [[os.path.join(examples_path, example)] for example in sorted(os.listdir(examples_path))]
    if os.path.isdir(examples_path)
    else []
)
projectApp = gr.Interface(
    fn = predict,
    inputs = gr.Image(type = 'pil'),
    # predict() returns (probabilities dict, elapsed seconds); one component each.
    outputs = [gr.Label(num_top_classes = len(class_names), label = 'Predictions'),
               gr.Number(label = 'Prediction time(s)')],
    # None (Gradio's default) when there are no examples to show.
    examples = example_list or None,
    title = title,
    description = description,
)
projectApp.launch()