import functools

import gradio as gr
import numpy as np
import pandas as pd
import torch
from minicons import cwe

from model import FeatureNormPredictor

# NOTE(review): every LM choice currently embeds with this one checkpoint, so
# selecting "albert" still uses BERT. TODO: map each LM choice to its own
# checkpoint once those models are available.
LM_CHECKPOINT = "bert-base-uncased"


@functools.lru_cache(maxsize=None)
def _load_lm(checkpoint):
    """Load a minicons contextual word embedding model, once per checkpoint.

    Loading a transformer is slow, so the result is cached instead of being
    rebuilt on every Gradio request.
    """
    return cwe.CWE(checkpoint)


def predict(Word, Sentence, LM, Layer, Norm):
    """Predict feature-norm values for ``Word`` as used in ``Sentence``.

    Args:
        Word: target word; must occur (as a substring) in ``Sentence``.
        Sentence: context sentence containing ``Word``.
        LM: language-model choice from the Radio ("bert"/"albert"), or None
            when nothing is selected.
        Layer: layer index from the Gradio "number" input. It may arrive as a
            float such as 8.0, or None when the field is empty.
        Norm: feature-norm choice from the Radio ("binder"/"mcrae"/
            "buchanan"), or None when nothing is selected.

    Returns:
        An ``"invalid input: ..."`` message when the inputs are rejected.
        Otherwise, placeholder text: the model name followed by one line per
        positive (randomly generated) fake feature value.
    """
    if not Word or not Sentence or Word not in Sentence:
        return "invalid input: word not in sentence"
    if not LM or not Norm:
        return "invalid input: choose a language model and a norm"
    # The number component yields floats (8.0); require an integral value so
    # the checkpoint name reads "bert8_to_binder", not "bert8.0_to_binder".
    if Layer is None or not float(Layer).is_integer():
        return "invalid input: layer not in lm"
    layer = int(Layer)

    lm = _load_lm(LM_CHECKPOINT)
    if layer not in range(lm.layers):
        return "invalid input: layer not in lm"
    model_name = f"{LM}{layer}_to_{Norm}"

    # TODO: replace the placeholder below with the trained predictor.
    # NOTE(review): `data` is a single (word, sentence) pair, so dividing the
    # summed embeddings by len(data) == 2 looks suspicious -- confirm the
    # shape returned by extract_representation before enabling this.
    # model = FeatureNormPredictor.load_from_checkpoint(
    #     checkpoint_path=model_name + '.ckpt',
    #     map_location=None,
    # )
    # model.eval()
    # with open(model_name + '.txt', "r") as file:
    #     labels = [line.rstrip() for line in file.readlines()]
    # data = (Word, Sentence)
    # embs = lm.extract_representation(data, layer=layer)
    # avg = embs.sum(0) / len(data)
    # pred = torch.nn.functional.relu(model(avg))
    # pred = pred.squeeze(0)
    # pred_list = pred.detach().numpy().tolist()
    # output = [labels[i] + '\t' + str(pred_list[i])
    #           for i in range(len(labels)) if pred_list[i] > 0.0]
    # return "All Positive Predicted Values:\n" + "\n".join(output)

    # Placeholder: random fake feature values until the predictor is wired in.
    labels = "These are some fake features".split(" ")
    vals = np.random.randint(-10, 10, len(labels))
    positives = [f"{label} {val}" for label, val in zip(labels, vals) if val > 0]
    message = model_name + " \n" + "\n".join(positives)
    print(message)
    # Return the text as well, so the Gradio output box is not left empty.
    return message


demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["bert", "albert"]),
        "number",
        gr.Radio(["binder", "mcrae", "buchanan"]),
    ],
    outputs=["text"],
)

if __name__ == "__main__":
    demo.launch()