# semantic-demo / app.py
# NOTE: the lines below were scraped from the Hugging Face file-viewer page
# (author jwalanthi, commit 541d16e — "it works locally!!", 1.64 kB) and are
# kept here as a comment so the file remains valid Python.
import gradio as gr
import torch
from minicons import cwe
import pandas as pd
import numpy as np
from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
def predict (Word, Sentence, LM, Layer, Norm):
    """Predict positive feature-norm values for Word in Sentence.

    Args:
        Word: target word; must occur (as a substring) in Sentence.
        Sentence: the context sentence.
        LM: language-model key ('bert' or 'albert').
        Layer: LM layer to extract the word representation from. Gradio
            "number" widgets deliver floats, so this is coerced to int —
            otherwise str(Layer) yields e.g. "8.0" and the checkpoint
            filename below would not match.
        Norm: feature-norm dataset name ('binder', 'mcrae', 'buchanan').

    Returns:
        A newline-separated report of all positive predicted feature
        values, or an error string describing the invalid input.
    """
    # Map UI choices to HuggingFace model names. The UI also offers
    # 'albert', so it must be mapped here or selecting it raises KeyError.
    # NOTE(review): 'albert-base-v2' assumed to be the intended checkpoint
    # base — confirm against the trained models in models/.
    models = {'bert': 'bert-base-uncased', 'albert': 'albert-base-v2'}
    if Word not in Sentence:
        return "invalid input: word not in sentence"
    if LM not in models:
        return "invalid input: lm not supported"
    Layer = int(Layer)  # gradio "number" inputs arrive as floats
    model_name = LM + str(Layer) + '_to_' + Norm
    lm = cwe.CWE(models[LM])
    if Layer not in range(lm.layers):
        return "invalid input: layer not in lm"
    # Load the trained feature-norm predictor for this LM/layer/norm combo.
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path='models/'+model_name+'.ckpt',
        map_location=None
    )
    model.eval()
    # One feature label per output dimension of the predictor.
    with open ('models/'+model_name+'.txt', "r") as file:
        labels = [line.rstrip() for line in file]
    data = (Sentence, Word)
    # BUG FIX: was hard-coded layer=8, silently ignoring the validated
    # Layer argument (and the layer encoded in the checkpoint name).
    emb = lm.extract_representation(data, layer=Layer)
    # ReLU clamps negatives to 0; only strictly positive values are reported.
    pred = torch.nn.functional.relu(model(emb))
    pred = pred.squeeze(0)
    pred_list = pred.detach().numpy().tolist()
    output = [labels[i]+'\t'+str(pred_list[i])
              for i in range(len(labels)) if pred_list[i] > 0.0]
    return "All Positive Predicted Values:\n"+"\n".join(output)
# Gradio UI: word, sentence, LM choice, layer number, norm-dataset choice —
# input order matches predict(Word, Sentence, LM, Layer, Norm).
demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["bert", "albert"]),
        "number",
        gr.Radio(["binder", "mcrae", "buchanan"]),
    ],
    outputs=["text"],
)

# BUG FIX: demo.launch() was previously also called unconditionally at module
# level, starting the server at import time and making the __main__ guard
# redundant (two launch calls when run as a script). Launch only when run
# directly.
if __name__ == "__main__":
    demo.launch()