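"""Gradio demo for a fine-tuned BERTBaseUncased classifier.

The app normalises an input sentence (collapsing @-mentions into a "mention_0"
placeholder), runs it through the checkpoint at config.MODEL_PATH, and returns
the model's label for it. The exact label set depends on how the model defined
in model.py was trained.
"""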
import sys

import gradio as gr
import torch

import config
import dataset
import engine
from model import BERTBaseUncased
from tokenizer import tokenizer
T = tokenizer.TweetTokenizer(
    preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)


def preprocess(text):
    """Tokenize a tweet and collapse each run of @-mentions into a single
    "mention_0" placeholder, leaving all other tokens unchanged."""
    tokens = T.tokenize(text)
    print(tokens, file=sys.stderr)
    ptokens = []
    for index, token in enumerate(tokens):
        if "@" in token:
            # Skip a mention whose previous token was also a mention, so that
            # consecutive mentions are reduced to one placeholder.
            if index > 0 and "@" in tokens[index - 1]:
                continue
            ptokens.append("mention_0")
        else:
            ptokens.append(token)
    print(ptokens, file=sys.stderr)
    return " ".join(ptokens)


def sentence_prediction(sentence):
    """Predict a label for a single sentence with the fine-tuned BERT model."""
    sentence = preprocess(sentence)
    model_path = config.MODEL_PATH

    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0]  # placeholder target; only the prediction is used
    )
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=3
    )

    device = config.device
    model = BERTBaseUncased()
    model.load_state_dict(torch.load(
        model_path, map_location=torch.device(device)))
    model.to(device)

    # engine.predict_fn returns a pair; only the first element (the outputs) is used.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)
    print(outputs)
    return {"label": outputs[0]}


demo = gr.Interface(
    fn=sentence_prediction,
    inputs=gr.Textbox(placeholder="Enter a sentence here..."),
    outputs="label",
    # interpretation="default",
    examples=[["!"]],
)

demo.launch()