import gradio as gr
import torch
from minicons import cwe
import pandas as pd
import numpy as np

from model import FeatureNormPredictor
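# Demo app: given a target word and a sentence containing it, predict the word's
# semantic feature norms (binder / mcrae / buchanan) from a contextual LM
# embedding via a trained FeatureNormPredictor. Prediction is currently stubbed
# with random values; the intended inference path is kept in the commented block.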


def predict(Word, Sentence, LM, Layer, Norm):
    if Word not in Sentence:
        return "invalid input: word not in sentence"
    # Gradio's number input returns a float; cast to an integer layer index.
    Layer = int(Layer)
    model_name = LM + str(Layer) + '_to_' + Norm
    lm = cwe.CWE('bert-base-uncased')
    if Layer not in range(lm.layers):
        return "invalid input: layer not in lm"

    labels = "These are some fake features".split(" ")
    vals = np.random.randint(-10,10,(5))
    return model_name+" \n"+"\n".join([labels[i]+" "+str(vals[i]) for i in range(len(labels)) if vals[i]>0])
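    # Intended inference path (currently disabled): load the trained checkpoint,
    # read its feature labels, embed the word in context with minicons, and
    # return every feature with a positive predicted value.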
    # model = FeatureNormPredictor.load_from_checkpoint(
    #     checkpoint_path=model_name+'.ckpt',
    #     map_location=None
    #     )
    # model.eval()

    # with open(model_name + '.txt', "r") as file:
    #     labels = [line.rstrip() for line in file.readlines()]

    # data = [(Sentence, Word)]  # minicons expects (sentence, word) pairs
    # embs = lm.extract_representation(data, layer=Layer)
    # avg = embs.sum(0) / len(data)
    # pred = torch.nn.functional.relu(model(avg))
    # pred = pred.squeeze(0)
    # pred_list = pred.detach().numpy().tolist()
    
    # output = [labels[i]+'\t'+str(pred_list[i]) for i in range(len(labels)) if pred_list[i] > 0.0]
    # return "All Positive Predicted Values:\n"+"\n".join(output)

demo = gr.Interface(
    fn=predict,
    inputs=[
        "text", 
        "text", 
        gr.Radio(["bert", "albert"]),
        "number",
        gr.Radio(["binder", "mcrae", "buchanan"]),
    ],
    outputs=["text"],
)

demo.launch()