# FoodVisionMini / app.py
# Th3BossC's picture
# Update app.py
# fe27419
import gradio as gr
import torch
import os
import torch.nn as nn
import torchvision
from model import create_model
from pathlib import Path
from timeit import default_timer as timer
from typing import Tuple, Dict
#################
# Heading and blurb passed to the Gradio interface.
title = 'FoodVisionMini - Diljith'
description = 'Classifies an image of food item into either one of 3 classes : pizza, steak or sushi'

# Labels reported to the user; predict() pairs class_names[i] with model
# output i. NOTE(review): the order presumably has to match the label order
# used when the checkpoint was trained - confirm before editing this list.
class_names = ['pizza', 'steak', 'sushi']

# Size the model from the label list so the two can never disagree.
# optimizer and lossFunc are returned by create_model but are not used by the
# code visible in this file; they are kept as module-level names for
# compatibility.
effnetb2, transforms, optimizer, lossFunc = create_model(num_classes=len(class_names))

# Restore the trained weights, mapping every tensor onto the CPU.
effnetb2.load_state_dict(
    torch.load(f='effnet_b2-20%-10epochs.pth', map_location=torch.device('cpu'))
)
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a food image and time the model's forward pass.

    Args:
        img: Image to classify. The Gradio interface in this file supplies a
            PIL image (``gr.Image(type='pil')``); it is passed straight to
            the module-level ``transforms`` callable.

    Returns:
        A tuple ``(probs, seconds)``. ``probs`` maps each entry of
        ``class_names`` to its softmax probability as a Python float.
        ``seconds`` is the time, measured with ``timeit.default_timer``,
        spent on the forward pass and on building ``probs``; preprocessing
        happens before the timer starts and is not included.
    """
    # Preprocess and add a leading batch dimension of size 1.
    batch = transforms(img).unsqueeze(0)

    # Put the model in evaluation mode before running inference.
    effnetb2.eval()

    start = timer()
    with torch.inference_mode():
        # Softmax over dim 1 turns the raw outputs into per-class probabilities.
        pred_probs = effnetb2(batch).softmax(dim=1)

    # Row 0 is the only sample in the batch; position idx belongs to
    # class_names[idx].
    probs_dict = {
        name: float(pred_probs[0][idx]) for idx, name in enumerate(class_names)
    }
    elapsed = timer() - start

    return probs_dict, elapsed
# Directory holding the sample images offered in the UI.
examples_path = Path('examples/')

# One single-element row per example file (the interface has one input).
# Paths are derived from examples_path instead of repeating the directory
# name, sorted so the order is stable across runs, and restricted to regular
# files so that a stray sub-directory is never offered as an image.
# as_posix() keeps the forward-slash form 'examples/<name>'.
example_list = [
    [example.as_posix()]
    for example in sorted(examples_path.iterdir())
    if example.is_file()
]
# Input widget: hands the uploaded picture to predict() as a PIL image.
image_input = gr.Image(type='pil')

# Output widgets, in the same order as the values returned by predict():
# first the per-class probabilities, then the elapsed time.
label_output = gr.Label(num_top_classes=len(class_names), label='Predictions')
time_output = gr.Number(label='Prediction time(s)')

# Assemble the demo from the pieces defined above.
projectApp = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[label_output, time_output],
    examples=example_list,
    title=title,
    description=description,
)

# Start serving the app with debug mode turned off.
projectApp.launch(debug=False)