# Emoji-Predictor / app.py
# ashish-001's picture
# Update app.py
# 428bc7f verified
import gradio as gr
from huggingface_hub import hf_hub_download
import joblib
from transformers import AutoTokenizer
import numpy as np
from dotenv import load_dotenv
import os
import onnxruntime as ort
class EmojiPrediction:
    """Predict emojis for a piece of text with an ONNX classifier.

    On construction the tokenizer, the ONNX model file
    (``model_quantized.onnx``), the pickled label encoder
    (``mlb_emoji_encoder.pkl``) and the threshold array (``thresholds.npy``)
    are fetched from the Hugging Face Hub repository ``repo_id``.

    Attributes are plain and public: ``model`` (ONNX Runtime session),
    ``tokenizer``, ``label_map`` (object exposing ``inverse_transform``),
    ``thresholds`` (NumPy array) and ``repo_id``.
    """

    DEFAULT_REPO_ID = "ashish-001/tweet-emoji-predictor"

    def __init__(self, repo_id=DEFAULT_REPO_ID):
        """Download and load everything needed for inference.

        Args:
            repo_id: Hub repository holding the tokenizer, model, encoder
                and thresholds. Defaults to the original hard-coded repo.
        """
        self.model = None
        self.tokenizer = None
        self.label_map = None
        self.thresholds = None
        self.repo_id = repo_id
        self.load_model()
        self.load_required_data()

    def load_model(self):
        """Load the tokenizer and open an ONNX Runtime session for the model."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.repo_id)
        onnx_file = hf_hub_download(
            repo_id=self.repo_id,
            filename='model_quantized.onnx',
        )
        self.model = ort.InferenceSession(onnx_file)

    def load_required_data(self):
        """Load the label encoder and the decision thresholds."""
        filepath = hf_hub_download(
            repo_id=self.repo_id,
            filename='mlb_emoji_encoder.pkl',
        )
        # SECURITY NOTE: joblib.load unpickles the file, which can execute
        # arbitrary code. Only point ``repo_id`` at a repository you trust.
        with open(filepath, 'rb') as f:
            self.label_map = joblib.load(f)

        threshold_filepath = hf_hub_download(
            repo_id=self.repo_id,
            filename='thresholds.npy',
        )
        self.thresholds = np.load(threshold_filepath)

    @staticmethod
    def _softmax(logits):
        """Return a numerically stable softmax over the last axis."""
        # Subtracting the row maximum leaves the result mathematically
        # unchanged but keeps np.exp from overflowing to inf (which would
        # turn the probabilities into nan and silently suppress every label).
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        exps = np.exp(shifted)
        return exps / np.sum(exps, axis=-1, keepdims=True)

    def predict_emoji(self, text):
        """Predict emojis for *text*.

        Args:
            text: Raw user input. ``None``, empty and whitespace-only input
                short-circuit without running the model.

        Returns:
            A ``(emojis, decorated)`` tuple: the predicted emojis joined into
            one string, and the original text followed by a space and those
            emojis. Both are empty strings for blank input.
        """
        if not text or not text.strip():
            return "", ""

        # Truncate so that very long input cannot exceed the maximum sequence
        # length the tokenizer declares for the model.
        encoded = self.tokenizer(
            text,
            truncation=True,
            return_tensors="np",
        )
        onnx_inputs = {
            "input_ids": encoded["input_ids"].astype(np.int64),
            "attention_mask": encoded["attention_mask"].astype(np.int64),
        }
        logits = self.model.run(None, onnx_inputs)[0]

        # NOTE(review): softmax makes the labels compete for probability mass
        # while the thresholds are compared element-wise; confirm the
        # thresholds were tuned on softmax outputs rather than sigmoids.
        probs = self._softmax(logits)
        predicted_labels = (probs >= self.thresholds).astype(int).reshape(1, -1)

        emojis = "".join(self.label_map.inverse_transform(predicted_labels)[0])
        return emojis, f"{text} {emojis}"
# Build the predictor once at import time so every UI event reuses the same
# loaded objects instead of downloading them again.
emojiprediction = EmojiPrediction()

# NOTE(review): the nesting below was reconstructed; it assumes the input box
# and both output boxes share the single row — confirm against the live UI.
with gr.Blocks() as app:
    gr.Markdown("# Tweet/Text Emoji Predictor")

    with gr.Row(equal_height=True):
        textbox = gr.Textbox(
            label="User Input",
            placeholder="Start entering the text",
            lines=1,
        )
        textbox1 = gr.Textbox(label="Raw Emoji Output")
        textbox2 = gr.Textbox(label="Text with Emojis")

    # Re-run the prediction every time the input text changes and fan the
    # two returned strings out to the two output boxes.
    textbox.change(
        emojiprediction.predict_emoji,
        outputs=[textbox1, textbox2],
        inputs=textbox,
    )

if __name__ == "__main__":
    app.launch()