File size: 1,706 Bytes
91dba4d
99ad741
5288696
2ffd102
 
6de72ef
541d16e
5288696
 
1cd9844
 
 
f50f2ed
1cd9844
 
f50f2ed
1cd9844
 
2ffd102
541d16e
2ffd102
541d16e
 
 
 
2ffd102
541d16e
 
 
 
 
 
 
f50f2ed
541d16e
 
91dba4d
 
5288696
6de72ef
 
 
1cd9844
6de72ef
91dba4d
 
 
1cd9844
541d16e
 
5eb0088
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import functools
import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from minicons import cwe

from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams


# UI display name -> (language-model id passed to minicons, stem of the
# classifier's .ckpt/.txt files in the hub repo below).
MODELS = {
    'Bert Layer 8 to Binder': ('bert-base-uncased', 'bert8_to_binder'),
    'Albert Layer 8 to Binder': ('albert-xxlarge-v2', 'albert8_to_binder_opt_stop'),
}
# Hub repo holding the trained classifier checkpoints and their label files.
_REPO_ID = "jwalanthi/semantic-feature-classifiers"
# Layer handed to extract_representation (matches "Layer 8" in MODELS keys).
_LAYER = 8


def _download(filename):
    """Fetch ``filename`` from the classifier repo and return its local path."""
    # TOKEN is optional: when unset, huggingface_hub uses its default
    # credential lookup instead of this code raising KeyError.
    return hf_hub_download(_REPO_ID, filename, token=os.environ.get('TOKEN'))


# maxsize=None is safe: keys come only from the fixed MODELS table.
@functools.lru_cache(maxsize=None)
def _load_lm(lm_name):
    """Build the minicons contextual-embedding wrapper once per process."""
    return cwe.CWE(lm_name)


@functools.lru_cache(maxsize=None)
def _load_classifier(stem):
    """Load the classifier checkpoint and its label list once per process.

    Returns:
        ``(model, labels)`` where ``model`` is in eval mode on CPU and
        ``labels`` is a tuple of label strings, one per line of ``stem.txt``.
    """
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=_download(stem + ".ckpt"),
        # predict() converts outputs with .tolist() on CPU; loading onto CPU
        # also avoids failures for GPU-saved checkpoints on CPU-only hosts.
        map_location="cpu",
    )
    model.eval()
    with open(_download(stem + ".txt"), "r", encoding="utf-8") as file:
        labels = tuple(line.rstrip() for line in file)
    return model, labels


def predict(Word, Sentence, modelname):
    """Predict semantic-feature values for ``Word`` as used in ``Sentence``.

    Args:
        Word: target word; must occur (as a substring) in ``Sentence``.
        Sentence: context sentence the word's embedding is extracted from.
        modelname: one of the keys of ``MODELS`` (the UI radio choice).

    Returns:
        ``"All Positive Predicted Values:"`` followed by one
        ``label<TAB>value`` line per label whose ReLU-clipped prediction is
        strictly positive, or an ``"invalid input: ..."`` message.
    """
    if not Word or not Word.strip():
        return "invalid input: word is empty"
    if Word not in Sentence:
        return "invalid input: word not in sentence"
    if modelname not in MODELS:
        return f"invalid input: unknown model {modelname!r}"

    lm_name, stem = MODELS[modelname]
    lm = _load_lm(lm_name)
    model, labels = _load_classifier(stem)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        emb = lm.extract_representation((Sentence, Word), layer=_LAYER)
        # squeeze(0) drops a leading batch dim; assumes the model returns
        # (1, n_labels) for a single (sentence, word) pair — TODO confirm.
        pred = torch.nn.functional.relu(model(emb)).squeeze(0)
    scores = pred.tolist()

    output = [
        f"{label}\t{score}"
        for label, score in zip(labels, scores)
        if score > 0.0
    ]
    return "All Positive Predicted Values:\n" + "\n".join(output)

# Web UI: word and sentence text boxes plus a radio selecting the
# language-model/classifier pair; the output is predict()'s text report.
demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        # Preselect a model so submitting without clicking the radio does
        # not pass None to predict().
        gr.Radio(
            ["Bert Layer 8 to Binder", "Albert Layer 8 to Binder"],
            value="Bert Layer 8 to Binder",
        ),
    ],
    outputs=["text"],
)

# Launch only when executed as a script, so importing this module does not
# start a server (launch() blocks the calling thread when run as a script).
if __name__ == "__main__":
    demo.launch(share=True)