|
import gradio as gr |
|
from transformers import AutoTokenizer |
|
import torch |
|
from fastai.text.all import * |
|
from blurr.text.data.all import * |
|
from blurr.text.modeling.all import * |
|
|
|
|
|
# Locations of the serialized artifacts (relative paths, so they are resolved
# against the process's current working directory).
model_path = "origin-classifier-stage-2.pkl"
dls_path = "dls_origin-classifier_v1.pkl"

# Learner restored from `model_path`; every prediction below goes through it.
learner_inf = load_learner(model_path)

# NOTE(review): `dls` is not referenced anywhere else in the visible code --
# the label vocabulary is read from `learner_inf.dls` instead. Kept because
# other code may rely on the name; confirm before removing. `torch.load`
# unpickles the file, so `dls_path` must only ever point at a trusted artifact.
dls = torch.load(dls_path)

# Label -> positional index, in the order the learner's vocab yields labels.
class_label_mapping = dict(
    (label, position) for position, label in enumerate(learner_inf.dls.vocab)
)
|
|
|
|
|
def predict_text(text):
    """Return the most probable class labels for ``text`` with their scores.

    The text is run through ``learner_inf.blurr_predict`` and the ``'probs'``
    entry of the first returned item is ranked.

    Args:
        text: The raw input string to classify.

    Returns:
        A dict mapping each of the (at most) five highest-probability labels
        to its probability as a plain ``float``, in descending order of
        probability. A label -> confidence dict is the shape the ``Label``
        output component wired to this function expects; a bare list of
        labels is rejected by it.
    """
    prediction = learner_inf.blurr_predict(text)[0]

    # NOTE(review): 'probs' is presumably a flat sequence of per-class
    # probabilities (list or tensor depending on the blurr version) -- confirm.
    # `as_tensor` accepts both, so the ranking below works either way.
    probs = torch.as_tensor(prediction['probs']).detach().flatten()

    # Invert the label -> index table once instead of rebuilding key/value
    # lists for every looked-up index.
    index_to_label = {idx: label for label, idx in class_label_mapping.items()}

    # Never ask for more entries than there are classes; topk raises otherwise.
    top_k = min(5, probs.numel())
    if top_k == 0:
        return {}

    scores, indices = probs.topk(top_k)
    return {
        str(index_to_label[idx]): score
        for idx, score in zip(indices.tolist(), scores.tolist())
    }
|
|
|
|
|
# The legacy `gr.outputs` namespace is absent from newer Gradio releases, where
# the component lives at `gr.Label`. Prefer the top-level class and fall back
# to the legacy path only on old installs that lack it.
_label_output = gr.Label if hasattr(gr, "Label") else gr.outputs.Label

# Single text box in, label widget (top five classes) out, backed by
# `predict_text`.
iface = gr.Interface(
    fn=predict_text,
    inputs="text",
    outputs=_label_output(num_top_classes=5),
    title="Food Origin Classification App",
    description="Enter a Recipe, and it will predict the class label.",
)

# Launches the app as soon as the module is executed.
iface.launch()
|
|