# semantic-demo / app.py
# File-viewer header: author jwalanthi; commit f50f2ed ("prediction added"); 1.3 kB.
import gradio as gr
import torch
from minicons import cwe
import pandas as pd
import numpy as np
from model import FeatureNormPredictor
def predict(Word, Sentence, LM, Layer, Norm):
    """Predict feature-norm values for a word as it is used in a sentence.

    Args:
        Word: Target word; must be a non-empty substring of ``Sentence``.
        Sentence: Sentence providing the context for ``Word``.
        LM: Language-model name; used as the prefix of the checkpoint and
            label file names.
        Layer: Layer of the language model to read the embedding from. May
            be an int or an integral float; ``None`` is rejected.
        Norm: Feature-norm name; used as the suffix of the checkpoint and
            label file names.

    Returns:
        A newline-joined string with one ``label<TAB>value`` line for every
        label whose predicted value is strictly positive, or an
        ``"invalid input: ..."`` message when the word or layer is rejected.
    """
    # An empty word is a substring of every sentence, so reject it explicitly.
    if not Word or Word not in Sentence:
        return "invalid input: word not in sentence"

    # The UI's number field may hand over a float (e.g. 8.0) or None when
    # left blank -- TODO confirm for the deployed Gradio version. Normalise
    # to an int so both the file name and the layer index are well-formed.
    try:
        layer = int(Layer)
    except (TypeError, ValueError, OverflowError):
        return "invalid input: layer not in lm"
    if layer != Layer:
        # Fractional layers such as 8.5 do not name a real layer.
        return "invalid input: layer not in lm"

    model_name = f"{LM}{layer}_to_{Norm}"

    # NOTE(review): the encoder is always bert-base-uncased, regardless of
    # the LM choice; only the checkpoint name depends on LM. Confirm which
    # encoders the roberta/electra checkpoints were trained on before
    # wiring LM through here.
    lm = cwe.CWE('bert-base-uncased')
    if layer not in range(lm.layers):
        return "invalid input: layer not in lm"

    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_name + '.ckpt',
        map_location=None,
    )
    model.eval()

    with open(model_name + '.txt', "r", encoding="utf-8") as file:
        labels = [line.rstrip() for line in file]

    # minicons expects (sentence, word) pairs, in that order.
    data = (Sentence, Word)
    # Use the validated, user-selected layer instead of a hard-coded one.
    embs = lm.extract_representation(data, layer=layer)
    # Average over the embeddings actually returned (first dimension);
    # len(data) is the tuple length (always 2), not the embedding count.
    avg = embs.mean(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = torch.nn.functional.relu(model(avg))
    pred_list = pred.squeeze(0).detach().cpu().numpy().tolist()

    # Keep only features with a positive predicted value (ReLU zeroes the rest).
    output = [
        f"{label}\t{score}"
        for label, score in zip(labels, pred_list)
        if score > 0.0
    ]
    return "\n".join(output)
# Input widgets, listed in the same order as predict's parameters:
# Word, Sentence, LM, Layer, Norm.
_input_components = [
    "text",
    "text",
    gr.Radio(["bert", "roberta", "electra"]),
    "number",
    gr.Radio(["binder", "mcrae", "buchanan"]),
]

# Wire predict to the widgets; the result is shown in a single text output.
demo = gr.Interface(fn=predict, inputs=_input_components, outputs=["text"])

# Start the web UI.
demo.launch()