File size: 1,774 Bytes
c4543eb 481c6b3 db2fdad 481c6b3 db2fdad 5f726f0 3c6e00e 99ad178 3c6e00e e5ffa90 db2fdad e5ffa90 64909f7 e5ffa90 b3b3581 481c6b3 6c86820 e5ffa90 6c86820 e5ffa90 6c86820 245cab5 6c86820 e5ffa90 6c86820 db2fdad ba0c1c3 db2fdad e5ffa90 db2fdad e5ffa90 d481ecd e42c394 b3b3581 e42c394 b3b3581 e42c394 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import torch
from utils import label_full_decoder
import sys
import dataset
import engine
from model import BERTBaseUncased
# from tokenizer import tokenizer
import config
from transformers import pipeline, AutoTokenizer, AutoModel
import gradio as gr
# Build the classifier once at import time so every request reuses it.
device = config.device
model = BERTBaseUncased()

# Read the saved weights from config.MODEL_PATH, remapping tensors onto the
# configured device, then copy them into the freshly built model.
_checkpoint = torch.load(config.MODEL_PATH, map_location=torch.device(device))
model.load_state_dict(_checkpoint)
del _checkpoint  # do not keep a second copy of the weights alive

model.to(device)
# T = tokenizer.TweetTokenizer(
# preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)
# def preprocess(text):
# tokens = T.tokenize(text)
# print(tokens, file=sys.stderr)
# ptokens = []
# for index, token in enumerate(tokens):
# if "@" in token:
# if index > 0:
# # check if previous token was mention
# if "@" in tokens[index-1]:
# pass
# else:
# ptokens.append("mention_0")
# else:
# ptokens.append("mention_0")
# else:
# ptokens.append(token)
# print(ptokens, file=sys.stderr)
# return " ".join(ptokens)
def sentence_prediction(sentence):
    """Run the module-level model on a single sentence and return its output.

    The sentence is wrapped in a one-item ``dataset.BERTDataset`` and fed
    through ``engine.predict_fn`` using the module-level ``model`` and
    ``device``.

    Args:
        sentence: Input text to classify.

    Returns:
        The first element of the pair returned by ``engine.predict_fn``.
        NOTE(review): its exact type is defined in ``engine``; confirm it is
        something the Gradio ``'label'`` output accepts (earlier code wrapped
        it as ``{"label": outputs[0]}``).
    """
    # Tweet preprocessing is currently disabled (the ``preprocess`` helper is
    # commented out above), so the raw sentence is used as-is.
    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0],  # presumably a dummy target; inference has no gold label
    )

    # DataLoader raises ValueError for a negative ``num_workers``; 0 loads the
    # single example in the main process.
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=0,
    )

    # predict_fn returns a pair; only the predictions are needed here.
    # The previous follow-up call to ``classifier(sentence)`` used a name that
    # is never defined in this file and discarded the model's prediction.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)
    print(outputs)
    return outputs
# Expose sentence_prediction through a Gradio UI: one text box in, one label
# component out, then start serving.
demo = gr.Interface(fn=sentence_prediction, inputs="text", outputs="label")
demo.launch()
|