"""Gradio demo: predict semantic feature norms for a word in context.

A contextual embedding of ``Word`` inside ``Sentence`` is extracted from a
pretrained language model (via ``minicons``) and fed to a trained
``FeatureNormPredictor`` checkpoint, which scores every feature-norm label.
"""

import functools

import gradio as gr
import numpy as np
import pandas as pd
import torch
from minicons import cwe

from model import FeatureNormPredictor

# Model identifiers loaded for each LM choice offered in the UI.
# NOTE(review): only 'bert-base-uncased' appeared in the original code; the
# roberta/electra identifiers are assumptions — confirm they match the models
# the '<lm><layer>_to_<norm>.ckpt' checkpoints were trained on.
LM_MODEL_IDS = {
    "bert": "bert-base-uncased",
    "roberta": "roberta-base",
    "electra": "google/electra-base-discriminator",
}


@functools.lru_cache(maxsize=len(LM_MODEL_IDS))
def _load_lm(model_id):
    """Load (once) and return the minicons contextual-word-embedding wrapper."""
    return cwe.CWE(model_id)


@functools.lru_cache(maxsize=8)
def _load_predictor(model_name):
    """Load (once) a predictor checkpoint and its label list.

    Reads ``<model_name>.ckpt`` and ``<model_name>.txt`` (one label per line)
    and returns a ``(model, labels)`` pair with the model in eval mode.
    """
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_name + ".ckpt",
        map_location=None,
    )
    model.eval()
    with open(model_name + ".txt", "r", encoding="utf-8") as file:
        labels = [line.rstrip() for line in file]
    return model, labels


def predict(Word, Sentence, LM, Layer, Norm):
    """Return predicted feature norms for ``Word`` as used in ``Sentence``.

    Args:
        Word: target word; must occur (as a substring) in ``Sentence``.
        Sentence: the context sentence.
        LM: one of the keys of ``LM_MODEL_IDS``.
        Layer: layer index to read representations from. Gradio's "number"
            input delivers a float, so integral floats such as 8.0 are
            accepted and converted to ``int``.
        Norm: feature-norm dataset name used to pick the checkpoint.

    Returns:
        Newline-separated ``"label<TAB>score"`` lines for every label whose
        ReLU-clipped score is strictly positive, or an ``"invalid input: ..."``
        message when the inputs cannot be used.
    """
    if Word not in Sentence:
        return "invalid input: word not in sentence"
    if LM not in LM_MODEL_IDS:
        return "invalid input: unknown lm"
    if not Norm:
        return "invalid input: no norm selected"

    try:
        layer = int(Layer)
    except (TypeError, ValueError, OverflowError):
        return "invalid input: layer not in lm"
    if layer != Layer:  # reject fractional layers such as 2.5
        return "invalid input: layer not in lm"

    lm = _load_lm(LM_MODEL_IDS[LM])
    if layer not in range(lm.layers):
        return "invalid input: layer not in lm"

    # int() keeps the checkpoint name free of a ".0" suffix from float input.
    model_name = LM + str(layer) + "_to_" + Norm
    model, labels = _load_predictor(model_name)

    # NOTE(review): minicons' documented examples appear to pass a list of
    # (sentence, word) pairs, whereas this passes a single (word, sentence)
    # tuple and divides the summed embeddings by len(data) == 2. Kept as in
    # the original in case the checkpoints were trained on the same inputs —
    # confirm against the training code and the minicons docs.
    data = (Word, Sentence)
    with torch.no_grad():
        embs = lm.extract_representation(data, layer=layer)
        avg = embs.sum(0) / len(data)
        pred = torch.nn.functional.relu(model(avg))
    pred_list = pred.squeeze(0).tolist()

    # Keep only labels with a positive score (ReLU zeroes the rest).
    output = [
        label + "\t" + str(score)
        for label, score in zip(labels, pred_list)
        if score > 0.0
    ]
    return "\n".join(output)


demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(list(LM_MODEL_IDS)),
        "number",
        gr.Radio(["binder", "mcrae", "buchanan"]),
    ],
    outputs=["text"],
)
demo.launch()