Spaces:
Running
Running
import gradio as gr | |
import torch | |
from minicons import cwe | |
import pandas as pd | |
import numpy as np | |
from model import FeatureNormPredictor | |
def predict(Word, Sentence, LM, Layer, Norm):
    """Return a text report of positive feature-norm values for a word.

    The real predictor is not wired in yet: after validating the inputs,
    this returns placeholder output (fixed fake labels paired with random
    integers), keeping only the strictly positive values.

    Args:
        Word: Target word; must be a non-empty substring of ``Sentence``.
        Sentence: Sentence that contains ``Word``.
        LM: Language-model family name, used as the checkpoint-name prefix.
        Layer: Layer index. Whole-number floats such as ``8.0`` are
            accepted and normalised to ``8``.
        Norm: Feature-norm name, used as the checkpoint-name suffix.

    Returns:
        Either an ``"invalid input: ..."`` message, or the checkpoint name
        followed by one ``"<label> <value>"`` line per positive value.
    """
    # An empty word is a substring of every sentence, so reject it explicitly.
    if not Word or Word not in Sentence:
        return "invalid input: word not in sentence"
    # Both selections are needed to build the checkpoint name; a missing one
    # (e.g. None) would otherwise fail during string concatenation.
    if not LM or not Norm:
        return "invalid input: language model and norm must be selected"

    # The layer may arrive as a float (presumably from the UI number field).
    # Reject missing, fractional or negative values before building the LM,
    # and use the integer form so the name is "bert8_...", not "bert8.0_...".
    try:
        layer = int(Layer)
    except (TypeError, ValueError, OverflowError):
        return "invalid input: layer not in lm"
    if layer != Layer or layer < 0:
        return "invalid input: layer not in lm"

    model_name = f"{LM}{layer}_to_{Norm}"

    # NOTE(review): the LM is always 'bert-base-uncased' regardless of `LM`,
    # and it is constructed on every call -- confirm whether that is intended.
    lm = cwe.CWE('bert-base-uncased')
    if layer not in range(lm.layers):
        return "invalid input: layer not in lm"

    # Placeholder output until the trained predictor below is enabled.
    labels = "These are some fake features".split(" ")
    vals = np.random.randint(-10, 10, len(labels))
    positives = [f"{label} {val}" for label, val in zip(labels, vals) if val > 0]
    return model_name + " \n" + "\n".join(positives)

    # TODO: replace the placeholder output above with the real predictor:
    # model = FeatureNormPredictor.load_from_checkpoint(
    #     checkpoint_path=model_name+'.ckpt',
    #     map_location=None
    # )
    # model.eval()
    # with open (model_name+'.txt', "r") as file:
    #     labels = [line.rstrip() for line in file.readlines()]
    # data = (Word, Sentence)
    # embs = lm.extract_representation(data, layer=8)
    # avg = embs.sum(0)/len(data)
    # pred = torch.nn.functional.relu(model(avg))
    # pred = pred.squeeze(0)
    # pred_list = pred.detach().numpy().tolist()
    # output = [labels[i]+'\t'+str(pred_list[i]) for i in range(len(labels)) if pred_list[i] > 0.0]
    # return "All Positive Predicted Values:\n"+"\n".join(output)
# Input widgets, listed in the positional order of predict()'s parameters.
input_widgets = [
    "text",  # Word
    "text",  # Sentence
    gr.Radio(["bert", "albert"]),  # LM
    "number",  # Layer
    gr.Radio(["binder", "mcrae", "buchanan"]),  # Norm
]
output_widgets = ["text"]

# Build the web UI around predict() and start serving it.
demo = gr.Interface(fn=predict, inputs=input_widgets, outputs=output_widgets)
demo.launch()