"""Gradio demo: predict feature-norm values for a word in context.

A contextual embedding for ``Word`` in ``Sentence`` is extracted from one
layer of a pretrained language model (via minicons ``cwe.CWE``) and fed to a
trained ``FeatureNormPredictor`` checkpoint; every feature label whose
(ReLU-clipped) prediction is positive is listed.
"""
import functools
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import torch
from minicons import cwe

from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams

# Hugging Face model ids for the LMs that have trained predictors.
# NOTE(review): the UI below also offers "albert", but no model id is mapped
# here, so that choice is answered with an "invalid input" message.
MODELS = {"bert": "bert-base-uncased"}

# Feature-norm datasets offered in the UI. Also used to whitelist the Norm
# argument, because it becomes part of a file path.
NORMS = ("binder", "mcrae", "buchanan")

# Directory holding "<lm><layer>_to_<norm>.ckpt" checkpoints and matching
# "<lm><layer>_to_<norm>.txt" label files (one label per line).
MODEL_DIR = Path("models")


@functools.lru_cache(maxsize=None)
def _load_lm(model_id):
    """Load (once per model id) and return the minicons CWE wrapper."""
    return cwe.CWE(model_id)


@functools.lru_cache(maxsize=8)
def _load_predictor(model_name):
    """Load (once per name) the predictor checkpoint, in eval mode, on CPU."""
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=str(MODEL_DIR / f"{model_name}.ckpt"),
        # Predictions are converted to Python lists on CPU, and the embedding
        # is assumed to be on CPU as well, so keep the model there too.
        map_location="cpu",
    )
    model.eval()
    return model


@functools.lru_cache(maxsize=8)
def _load_labels(model_name):
    """Return the feature labels for ``model_name`` (one per output unit)."""
    with open(MODEL_DIR / f"{model_name}.txt", "r", encoding="utf-8") as file:
        return tuple(line.rstrip() for line in file)


def predict(Word, Sentence, LM, Layer, Norm):
    """Predict feature-norm values for ``Word`` as it occurs in ``Sentence``.

    Args:
        Word: target word; must occur (as a substring) in ``Sentence``.
        Sentence: context sentence.
        LM: key into ``MODELS`` (e.g. ``"bert"``).
        Layer: LM layer to read the embedding from. Gradio's number input
            delivers a float, so integral floats such as ``8.0`` are accepted.
        Norm: one of ``NORMS``.

    Returns:
        A human-readable string: either an ``"invalid input: ..."`` message,
        or a header followed by one ``label<TAB>value`` line for every
        positive prediction.
    """
    if not Word or not Sentence or Word not in Sentence:
        return "invalid input: word not in sentence"
    if LM not in MODELS:
        return "invalid input: lm not supported"
    if Norm not in NORMS:
        return "invalid input: norm not supported"

    # Normalise the layer to an int so it formats as "8", not "8.0", in the
    # checkpoint name; reject missing or non-integral values.
    try:
        layer = int(Layer)
    except (TypeError, ValueError, OverflowError):
        return "invalid input: layer not in lm"
    if layer != Layer:
        return "invalid input: layer not in lm"

    lm = _load_lm(MODELS[LM])
    if layer not in range(lm.layers):
        return "invalid input: layer not in lm"

    model_name = f"{LM}{layer}_to_{Norm}"
    try:
        model = _load_predictor(model_name)
        labels = _load_labels(model_name)
    except FileNotFoundError:
        return "invalid input: no trained model for " + model_name

    # Read the embedding from the layer the user actually selected.
    emb = lm.extract_representation((Sentence, Word), layer=layer)
    with torch.no_grad():
        pred = torch.nn.functional.relu(model(emb))
    # Assumes the predictor output is (1, n_labels); squeeze(0) drops the
    # batch axis -- TODO confirm against FeatureNormPredictor.
    values = pred.squeeze(0).tolist()

    output = [
        f"{label}\t{value}"
        for label, value in zip(labels, values)
        if value > 0.0
    ]
    return "All Positive Predicted Values:\n" + "\n".join(output)


demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["bert", "albert"]),
        "number",
        gr.Radio(list(NORMS)),
    ],
    outputs=["text"],
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module does not
    # start a server.
    demo.launch()