# Hugging Face file-page header captured with this file (not code):
# author: sanjid | commit 3b5af6e ("aopp") | raw file size: 1.06 kB
import inspect

import gradio as gr
from transformers import AutoTokenizer
import torch
from fastai.text.all import *
from blurr.text.data.all import *
from blurr.text.modeling.all import *
# Define the paths to the exported model and the pickled dataloaders.
model_path = "origin-classifier-stage-2.pkl"
dls_path = "dls_origin-classifier_v1.pkl"

# Restore the exported fastai Learner used for inference below.
learner_inf = load_learner(model_path)

# ``dls_path`` holds a fully pickled object, not a bare tensor state_dict.
# torch>=2.6 defaults ``torch.load(..., weights_only=True)``, which refuses such
# pickles, so opt out explicitly when the installed torch supports the argument
# (older releases lack it and would raise TypeError on the keyword).
# ``map_location="cpu"`` lets tensors saved from a GPU session load on a
# CPU-only host.
# SECURITY: full unpickling can execute arbitrary code -- only load trusted files.
_torch_load_kwargs = {"map_location": "cpu"}
if "weights_only" in inspect.signature(torch.load).parameters:
    _torch_load_kwargs["weights_only"] = False
# NOTE(review): ``dls`` is not referenced in the visible code (prediction goes
# through ``learner_inf`` and its own ``dls``); confirm this load is still needed.
dls = torch.load(dls_path, **_torch_load_kwargs)

# Map each entry of the learner's vocab to its position. Presumably these are
# the class labels of this classifier -- confirm. Not referenced elsewhere in
# the visible code.
class_label_mapping = {label: idx for idx, label in enumerate(learner_inf.dls.vocab)}
def predict_text(text):
    """Classify a recipe and return per-class confidences for a ``gr.Label``.

    Args:
        text: Recipe text entered by the user.

    Returns:
        dict mapping each class label to its score as a plain ``float``.
        ``gr.Label`` expects numeric confidences: it orders the entries by value
        to choose the top classes and formats the percentage itself. Returning
        pre-formatted strings (e.g. ``"9.50%"``) makes that ordering
        lexicographic, so ``"9.50%"`` would outrank ``"85.00%"``.
    """
    # Only the first result is read; presumably blurr_predict returns one
    # result dict per input item -- confirm against the blurr version in use.
    prediction = learner_inf.blurr_predict(text)[0]
    predicted_probs = prediction['scores']
    predicted_labels = prediction['class_labels']
    # float() turns numpy/torch scalar scores (if any) into JSON-serializable
    # Python floats.
    return {label: float(prob) for label, prob in zip(predicted_labels, predicted_probs)}
# Build and start the web UI. The ``gr.inputs`` / ``gr.outputs`` namespaces were
# deprecated in Gradio 3.x and removed in 4.0. The top-level ``gr.Textbox`` and
# ``gr.Label`` components accept the same arguments and work on both.
iface = gr.Interface(
    fn=predict_text,
    inputs=gr.Textbox(lines=2, placeholder='Enter Recipe Here...'),
    # Display the three highest-scoring classes from predict_text's mapping.
    outputs=gr.Label(num_top_classes=3),
    title="Food Origin Classification App",
    description="Enter a Recipe, and it will predict the class label.",
)
iface.launch()