import gradio as gr
import torch
from minicons import cwe
from huggingface_hub import hf_hub_download
import os
from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
def predict(Word, Sentence, modelname):
    # Map the UI model choice to (language model name, classifier checkpoint name)
    models = {
        'Bert Layer 8 to Binder': ('bert-base-uncased', 'bert8_to_binder'),
        'Albert Layer 8 to Binder': ('albert-xxlarge-v2', 'albert8_to_binder_opt_stop'),
    }
    if Word not in Sentence:
        return "invalid input: word not in sentence"
    model_name = models[modelname][1]
    lm = cwe.CWE(models[modelname][0])
    # Download the classifier checkpoint and its feature-label file from the Hub
    model_path = hf_hub_download(
        "jwalanthi/semantic-feature-classifiers",
        model_name + ".ckpt",
        use_auth_token=os.environ['TOKEN'],
    )
    label_path = hf_hub_download(
        "jwalanthi/semantic-feature-classifiers",
        model_name + ".txt",
        use_auth_token=os.environ['TOKEN'],
    )
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_path,
        map_location=None,
    )
    model.eval()
    with open(label_path, "r") as file:
        labels = [line.rstrip() for line in file.readlines()]
    # Extract the contextual embedding of the word in the sentence (layer 8)
    data = (Sentence, Word)
    emb = lm.extract_representation(data, layer=8)
    # Predict feature values and keep only the positive ones
    pred = torch.nn.functional.relu(model(emb))
    pred = pred.squeeze(0)
    pred_list = pred.detach().numpy().tolist()
    output = [labels[i] + '\t' + str(pred_list[i]) for i in range(len(labels)) if pred_list[i] > 0.0]
    return "All Positive Predicted Values:\n" + "\n".join(output)
demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["Bert Layer 8 to Binder", "Albert Layer 8 to Binder"]),
    ],
    outputs=["text"],
)
if __name__ == "__main__":
    demo.launch(share=True)