import os
from functools import lru_cache

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from minicons import cwe

from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams

# Language-model choices accepted by predict(), mapped to Hugging Face model ids.
# NOTE(review): "albert" is offered in the UI below but has no entry here, so
# predict() reports it as unsupported instead of raising KeyError.
MODELS = {"bert": "bert-base-uncased"}

# Hub repository that holds the "<lm><layer>_to_<norm>" checkpoint (.ckpt) and
# label (.txt) files downloaded by predict().
REPO_ID = "jwalanthi/bert_layer8_to_binder"

INVALID_LAYER = "invalid input: layer not in lm"


@lru_cache(maxsize=None)
def _load_lm(model_id):
    """Load the minicons contextual-embedding extractor for *model_id* once and reuse it."""
    return cwe.CWE(model_id)


def _download(filename):
    """Download *filename* from REPO_ID, authenticating with $TOKEN when it is set."""
    return hf_hub_download(
        REPO_ID,
        filename,
        # .get(): a missing TOKEN falls back to the default hub credentials
        # instead of crashing with KeyError.
        use_auth_token=os.environ.get("TOKEN"),
    )


@lru_cache(maxsize=None)
def _load_predictor(model_name):
    """Download, load and cache the predictor checkpoint and its label list for *model_name*."""
    model_path = _download(model_name + ".ckpt")
    label_path = _download(model_name + ".txt")

    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_path,
        map_location=None,
    )
    model.eval()

    with open(label_path, "r", encoding="utf-8") as file:
        labels = [line.rstrip() for line in file]
    return model, labels


def predict(Word, Sentence, LM, Layer, Norm):
    """Predict feature-norm values for *Word* as it is used in *Sentence*.

    Args:
        Word: target word; must occur (as a substring) in *Sentence*.
        Sentence: the context sentence.
        LM: key into MODELS selecting the language model.
        Layer: index of the LM layer whose representation is fed to the
            predictor. Gradio's "number" input delivers a float, so integral
            floats such as 8.0 are accepted and treated as 8.
        Norm: feature-norm dataset name; it is part of the checkpoint name.

    Returns:
        "All Positive Predicted Values:" followed by one "label<TAB>value" line
        per feature whose ReLU-ed prediction is > 0, or an "invalid input: ..."
        message when the inputs fail validation.
    """
    if Word not in Sentence:
        return "invalid input: word not in sentence"
    if LM not in MODELS:
        return "invalid input: lm not supported"

    # Normalise the (possibly float / empty) layer to an int, rejecting
    # missing or non-integral values.
    try:
        layer = int(Layer)
    except (TypeError, ValueError, OverflowError):
        return INVALID_LAYER
    if layer != Layer:
        return INVALID_LAYER

    lm = _load_lm(MODELS[LM])
    if layer not in range(lm.layers):
        return INVALID_LAYER

    model_name = f"{LM}{layer}_to_{Norm}"
    model, labels = _load_predictor(model_name)

    with torch.no_grad():
        # Use the requested layer (previously hard-coded to 8).
        emb = lm.extract_representation((Sentence, Word), layer=layer)
        pred = torch.nn.functional.relu(model(emb)).squeeze(0)

    pred_list = pred.tolist()
    output = [
        f"{label}\t{value}"
        for label, value in zip(labels, pred_list)
        if value > 0.0
    ]
    return "All Positive Predicted Values:\n" + "\n".join(output)


demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["bert", "albert"]),
        "number",
        gr.Radio(["binder", "mcrae", "buchanan"]),
    ],
    outputs=["text"],
)

# Launch only when run as a script; the previous unconditional launch also
# blocked on import and launched twice.
if __name__ == "__main__":
    demo.launch()