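"""Gradio demo app: serve a fine-tuned BERTBaseUncased tweet classifier.

The script loads trained weights from config.MODEL_PATH, normalizes
@-mentions in the input tweet, runs the model, and shows the predicted
label through a minimal text-in/label-out interface.
"""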
import sys

import torch
import gradio as gr

import config
import dataset
import engine
from model import BERTBaseUncased
from tokenizer import tokenizer
from utils import label_full_decoder

device = config.device

model = BERTBaseUncased()
# strict=False tolerates missing/unexpected keys in the checkpoint.
model.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device(device)), strict=False)
model.to(device)
model.eval()  # inference only: disable dropout

# Tweet-aware tokenizer: keep @handles and #hashtags, lowercase, strip URLs.
T = tokenizer.TweetTokenizer(
    preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False
)

def preprocess(text):
    """Tokenize a tweet and collapse each run of @-mentions into "mention_0"."""
    tokens = T.tokenize(text)
    print(tokens, file=sys.stderr)
    ptokens = []
    for index, token in enumerate(tokens):
        if "@" in token:
            # Skip a mention whose predecessor was also a mention, so a run
            # of consecutive mentions becomes a single placeholder.
            if index > 0 and "@" in tokens[index - 1]:
                continue
            ptokens.append("mention_0")
        else:
            ptokens.append(token)

    print(ptokens, file=sys.stderr)
    return " ".join(ptokens)
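
# Illustrative example (assumes the tweet tokenizer lowercases and splits
# handles on whitespace):
#   preprocess("Hey @user1 @user2 check this out")
#   -> "hey mention_0 check this out"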


def sentence_prediction(sentence):
    """Predict the label for a single tweet."""
    sentence = preprocess(sentence)

    # Wrap the single sentence in the dataset class used for training;
    # the target is a dummy value, unused at inference time.
    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0],
    )

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=2,
    )

    # predict_fn returns a pair; the second element (targets) is unused here.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)

    print(outputs, file=sys.stderr)
    # label_full_decoder (from utils) is assumed to map the predicted class
    # id back to its human-readable label.
    return label_full_decoder(outputs[0])




demo = gr.Interface(
    fn=sentence_prediction,
    inputs="text",
    outputs="label",
)

demo.launch()
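
# Quick sanity check without the UI (illustrative):
#   print(sentence_prediction("Loving the new update @devteam!"))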