File size: 1,295 Bytes
91dba4d
99ad741
5288696
 
6de72ef
 
5288696
 
 
f50f2ed
 
 
5288696
f50f2ed
 
5288696
 
 
 
 
f50f2ed
 
 
 
 
 
 
 
 
 
 
 
 
91dba4d
 
5288696
6de72ef
 
 
5288696
6de72ef
f50f2ed
6de72ef
91dba4d
 
 
6de72ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gradio as gr
import torch
from minicons import cwe
import pandas as pd
import numpy as np

from model import FeatureNormPredictor


def predict(Word, Sentence, LM, Layer, Norm):
    """Predict semantic feature-norm values for a word in its sentence context.

    A contextual embedding of ``Word`` as it occurs in ``Sentence`` is
    extracted at layer ``Layer`` and fed to the ``FeatureNormPredictor``
    checkpoint stored in ``<LM><Layer>_to_<Norm>.ckpt``. The outputs are
    paired, in order, with the labels read from ``<LM><Layer>_to_<Norm>.txt``
    (one label per line).

    Args:
        Word: Target word; must be a non-empty substring of ``Sentence``.
        Sentence: Sentence that provides the context for ``Word``.
        LM: Language-model prefix used to build the checkpoint/label file
            names.
        Layer: Layer index. Must be a whole number inside
            ``range(lm.layers)``; floats with an integral value (as produced
            by a numeric UI widget) are accepted.
        Norm: Feature-norm name used to build the checkpoint/label file
            names.

    Returns:
        A newline-joined string with one tab-separated "label, value" line
        per feature whose predicted value is strictly positive, or an
        ``"invalid input: ..."`` message when validation fails.
    """
    if not Word or Word not in Sentence:
        return "invalid input: word not in sentence"

    # Numeric UI widgets deliver floats (or None when left empty). Normalise
    # to int so the file name is e.g. "bert8_to_..." rather than "bert8.0_to_...".
    try:
        layer = int(Layer)
    except (TypeError, ValueError, OverflowError):
        return "invalid input: layer not in lm"
    if layer != Layer:
        return "invalid input: layer not in lm"

    model_name = f"{LM}{layer}_to_{Norm}"

    # NOTE(review): the encoder is fixed to BERT even when LM is another
    # model family; the matching model ids for the other choices are not
    # visible here, so this is left unchanged -- confirm and map LM to the
    # right pretrained model.
    lm = cwe.CWE('bert-base-uncased')
    if layer not in range(lm.layers):
        return "invalid input: layer not in lm"

    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_name + '.ckpt',
        map_location=None,
    )
    model.eval()

    with open(model_name + '.txt', "r", encoding="utf-8") as file:
        labels = [line.rstrip() for line in file]

    # minicons expects a list of (sentence, word) pairs.
    data = [(Sentence, Word)]
    with torch.no_grad():  # inference only; no autograd graph needed
        embs = lm.extract_representation(data, layer=layer)
        # Average over the extracted embeddings (one row per input pair).
        avg = embs.mean(dim=0)
        pred = torch.nn.functional.relu(model(avg))
    pred_list = pred.detach().reshape(-1).numpy().tolist()

    # Keep only features with a strictly positive predicted value.
    output = [
        label + '\t' + str(value)
        for label, value in zip(labels, pred_list)
        if value > 0.0
    ]
    return "\n".join(output)

# Input widgets, listed in the positional order of predict()'s parameters:
# Word, Sentence, LM, Layer, Norm.
_INPUT_COMPONENTS = [
    "text",
    "text",
    gr.Radio(["bert", "roberta", "electra"]),
    "number",
    gr.Radio(["binder", "mcrae", "buchanan"]),
]

# Web UI that forwards the widget values to predict() and shows its string
# result in a single text output.
demo = gr.Interface(fn=predict, inputs=_INPUT_COMPONENTS, outputs=["text"])

demo.launch()