File size: 1,641 Bytes
91dba4d
99ad741
5288696
 
6de72ef
 
541d16e
5288696
 
f50f2ed
541d16e
f50f2ed
 
541d16e
f50f2ed
 
541d16e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f50f2ed
541d16e
 
91dba4d
 
5288696
6de72ef
 
 
4be65da
6de72ef
f50f2ed
6de72ef
91dba4d
 
 
541d16e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
import torch
from minicons import cwe
import pandas as pd
import numpy as np

from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams


def predict(Word, Sentence, LM, Layer, Norm):
    """Predict feature-norm values for ``Word`` as used in ``Sentence``.

    Extracts the contextual embedding of ``Word`` from layer ``Layer`` of the
    language model selected by ``LM`` (via minicons ``cwe.CWE``), runs it
    through the ``FeatureNormPredictor`` checkpoint
    ``models/<LM><Layer>_to_<Norm>.ckpt``, and reports every feature whose
    ReLU-clamped prediction is positive. Feature names are read, one per
    line, from ``models/<LM><Layer>_to_<Norm>.txt``.

    Args:
        Word: target word; must be non-empty and occur in ``Sentence``.
        Sentence: context sentence.
        LM: language-model key; only ``"bert"`` is currently mapped.
        Layer: layer index. Integral floats (e.g. ``8.0``, which a numeric
            UI input may deliver) are accepted and converted to ``int``.
        Norm: feature-norm dataset name, e.g. ``"binder"``.

    Returns:
        ``"All Positive Predicted Values:"`` followed by one
        ``"<label>\\t<value>"`` line per positive feature, or an
        ``"invalid input: ..."`` message when an input is rejected.
    """
    models = {'bert': 'bert-base-uncased'}

    # `"" in Sentence` is always True, so an empty word is rejected explicitly.
    if not Word or not Sentence or Word not in Sentence:
        return "invalid input: word not in sentence"
    if LM not in models:
        # Unmapped keys (or no selection) would otherwise raise KeyError below.
        return "invalid input: lm not supported"
    if not Norm:
        # Norm is concatenated into the checkpoint path; None would raise TypeError.
        return "invalid input: norm not selected"

    # Accept only integral layer values so that both the checkpoint name and
    # the hidden-state index are plain ints (8.0 -> 8; 8.5 is rejected).
    try:
        layer = int(Layer)
    except (TypeError, ValueError, OverflowError):
        return "invalid input: layer not in lm"
    if layer != Layer:
        return "invalid input: layer not in lm"

    lm = cwe.CWE(models[LM])
    if layer not in range(lm.layers):
        return "invalid input: layer not in lm"

    model_name = f"{LM}{layer}_to_{Norm}"
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path='models/' + model_name + '.ckpt',
        map_location=None,
    )
    model.eval()

    with open('models/' + model_name + '.txt', "r", encoding="utf-8") as file:
        labels = [line.rstrip() for line in file]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Extract from the requested layer so the embedding matches the
        # layer the loaded checkpoint was trained on.
        emb = lm.extract_representation((Sentence, Word), layer=layer)
        pred = torch.nn.functional.relu(model(emb)).squeeze(0)
    pred_list = pred.cpu().tolist()

    output = [
        f"{label}\t{value}"
        for label, value in zip(labels, pred_list)
        if value > 0.0
    ]
    return "All Positive Predicted Values:\n" + "\n".join(output)

# Build the Gradio UI. The input components map positionally onto
# predict(Word, Sentence, LM, Layer, Norm).
demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["bert", "albert"]),
        "number",
        gr.Radio(["binder", "mcrae", "buchanan"]),
    ],
    outputs=["text"],
)

if __name__ == "__main__":
    # Launch only when run as a script: importing this module must not start
    # a server, and the app must not be launched a second time.
    demo.launch()