Spaces:
Running
Running
import gradio as gr | |
import torch | |
from minicons import cwe | |
from huggingface_hub import hf_hub_download | |
import os | |
from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams | |
def predict(Word, Sentence, LLM, Norm, Layer):
    """Predict semantic-feature values for ``Word`` as it occurs in ``Sentence``.

    A contextual embedding of the word is extracted from the chosen language
    model, fed through a feature-norm classifier downloaded from the
    Hugging Face Hub, and every feature with a positive (post-ReLU) value is
    listed in the returned text.

    Args:
        Word: Target word; must be a non-empty substring of ``Sentence``.
        Sentence: Sentence providing the context for ``Word``.
        LLM: Display name of the language model; one of the keys of the
            ``models`` table below ('BERT', 'ALBERT', 'roBERTa').
        Norm: Name of the feature-norm set (e.g. 'Binder'); lower-cased to
            build the checkpoint file name.
        Layer: Layer index. Selects both the classifier checkpoint and the
            layer the embedding is extracted from.

    Returns:
        A human-readable string: either an "invalid input: ..." message or
        the header line followed by one "label<tabs>value" line per feature
        whose predicted value is greater than zero.

    Raises:
        KeyError: If the ``TOKEN`` environment variable is not set.
    """
    models = {
        'BERT': 'bert-base-uncased',
        'ALBERT': 'albert-xxlarge-v2',
        'roBERTa': 'roberta-base',
    }

    # Validate everything up front so bad UI input yields a message instead
    # of an AttributeError/KeyError deep inside the pipeline.
    if not Word or not Sentence or Word not in Sentence:
        return "invalid input: word not in sentence"
    if LLM not in models:
        return "invalid input: unknown language model"
    if not Norm:
        return "invalid input: no norm selected"

    # UI sliders may deliver the value as a float; a float would corrupt the
    # checkpoint file name ("layer8.0").
    layer = int(Layer)

    model_name_hf = LLM.lower()
    norm_name_hf = Norm.lower()
    lm = cwe.CWE(models[LLM])

    repo_id = "jwalanthi/semantic-feature-classifiers"
    # NOTE: must be an f-string; without the prefix the literal text
    # "{model_name_hf}_models_all" was used as the subfolder.
    subfolder = f"{model_name_hf}_models_all"
    name_hf = f"{model_name_hf}_to_{norm_name_hf}_layer{layer}"
    token = os.environ['TOKEN']
    model_path = hf_hub_download(
        repo_id=repo_id,
        subfolder=subfolder,
        filename=f"{name_hf}.ckpt",
        use_auth_token=token,
    )
    label_path = hf_hub_download(
        repo_id=repo_id,
        subfolder=subfolder,
        filename=f"{name_hf}.txt",
        use_auth_token=token,
    )

    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=model_path,
        map_location=None,
    )
    model.eval()

    # One feature label per line, in the same order as the model outputs
    # (presumably -- TODO confirm against the training code).
    with open(label_path, "r") as file:
        labels = [line.rstrip() for line in file]

    data = (Sentence, Word)
    # Extract the embedding from the same layer the classifier was selected
    # for (this was previously hard-coded to layer 8 regardless of ``Layer``).
    emb = lm.extract_representation(data, layer=layer)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = torch.nn.functional.relu(model(emb))
    pred_list = pred.squeeze(0).tolist()

    output = [
        label + '\t\t\t\t\t\t\t' + str(value)
        for label, value in zip(labels, pred_list)
        if value > 0.0
    ]
    return "All Positive Predicted Values:\n" + "\n".join(output)
# Gradio UI definition. The five input components are passed positionally to
# predict(Word, Sentence, LLM, Norm, Layer); the result is shown as text.
demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(["BERT"]),
        gr.Radio(["Binder", "McRae", "Buchanan"]),
        gr.Slider(0, 12, step=1),
    ],
    outputs=["text"],
)

# Launch exactly once, and only when executed as a script, so that importing
# this module does not start a server.
if __name__ == "__main__":
    demo.launch()