# semantic-demo / app.py
# (Hugging Face Spaces file-viewer header converted to comments so the file parses:
#  author jwalanthi, revision 5eb0088, 1.88 kB)
import gradio as gr
import torch
from minicons import cwe
from huggingface_hub import hf_hub_download
import os
from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
def predict(Word, Sentence, LM, Layer, Norm):
    """Predict positive semantic feature norms for *Word* as used in *Sentence*.

    Args:
        Word: target word; must appear verbatim in ``Sentence``.
        Sentence: context sentence containing ``Word``.
        LM: language-model key (currently only ``'bert'`` has a backing model).
        Layer: hidden layer to probe; gradio "number" inputs arrive as floats,
            so it is coerced to ``int`` before use.
        Norm: feature-norm dataset name (``'binder'``/``'mcrae'``/``'buchanan'``).

    Returns:
        A newline-separated string of ``label\\tvalue`` pairs for every feature
        with a positive predicted value, or an "invalid input" message when a
        guard fails.
    """
    models = {'bert': 'bert-base-uncased'}
    # Guard clauses: return a user-facing message instead of raising, since the
    # return value is rendered directly in the gradio UI.
    if LM not in models:
        return "invalid input: lm not supported"
    if Word not in Sentence:
        return "invalid input: word not in sentence"
    layer = int(Layer)  # gradio "number" yields float; "8.0" would break the filenames below
    model_name = LM + str(layer) + '_to_' + Norm
    lm = cwe.CWE(models[LM])
    if layer not in range(lm.layers):
        return "invalid input: layer not in lm"
    # Checkpoint and label files live in a private repo; TOKEN is set in the
    # Space's secrets.
    model_path = hf_hub_download("jwalanthi/bert_layer8_to_binder", model_name+".ckpt", use_auth_token=os.environ['TOKEN'])
    label_path = hf_hub_download("jwalanthi/bert_layer8_to_binder", model_name+".txt", use_auth_token=os.environ['TOKEN'])
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_path,
        map_location=None
    )
    model.eval()
    with open(label_path, "r") as file:
        labels = [line.rstrip() for line in file.readlines()]
    data = (Sentence, Word)
    # BUG FIX: layer was hard-coded to 8 here, ignoring the validated Layer
    # argument and silently mismatching the layer-specific checkpoint.
    emb = lm.extract_representation(data, layer=layer)
    # ReLU clamps negatives; only strictly positive features are reported.
    pred = torch.nn.functional.relu(model(emb))
    pred = pred.squeeze(0)
    pred_list = pred.detach().numpy().tolist()
    output = [labels[i]+'\t'+str(pred_list[i]) for i in range(len(labels)) if pred_list[i] > 0.0]
    return "All Positive Predicted Values:\n"+"\n".join(output)
# Gradio UI: word + sentence text boxes, model/norm radios, and a layer number.
demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["bert", "albert"]),
        "number",
        gr.Radio(["binder", "mcrae", "buchanan"]),
    ],
    outputs=["text"],
)

# BUG FIX: demo.launch() was previously also called unconditionally at module
# level, which launched a server as an import side effect and made this guard
# dead code (the guarded share=True launch was never reached).
if __name__ == "__main__":
    demo.launch(share=True)