Spaces:
Runtime error
Runtime error
File size: 5,121 Bytes
65eef23 93798d6 89e27db 8f3eda5 9b78f9c 8f3eda5 e56055d 8f3eda5 93798d6 89e27db 902d725 8f3eda5 93798d6 9b78f9c 9f3c7b7 b3740e7 89e27db bb4bb43 9b78f9c bb4bb43 9b78f9c bb4bb43 93798d6 b3740e7 93798d6 902d725 93798d6 902d725 9b78f9c 902d725 9b78f9c bb4bb43 9b78f9c bb4bb43 902d725 9db84b2 902d725 93798d6 db7fef9 93798d6 b3740e7 93798d6 db7fef9 93798d6 db7fef9 93798d6 9f3c7b7 93798d6 9f3c7b7 b3740e7 db7fef9 93798d6 db7fef9 93798d6 42f5f04 93798d6 b3740e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import streamlit as st
import pandas as pd
import time
import importlib
from torch.cuda import is_available as use_cuda
import algs
import config
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Editing algorithms offered in the selector. Each label has the form
# "<ABBRV>: <description>"; the text before the first ":" is used to locate
# the implementation module (algs.<abbrv>), its class (<ABBRV>) and its
# config attribute (config.<abbrv>_config).
EDIT_ALGS = list((
    "MEND: Model editor networks using gradient decomposition",
    "SERAC: Semi-parametric editing with a retrieval-augmented counterfactual model",
    "ENN: Editable neural networks",
    "KE: KnowledgeEditor",
    "FT: Fine-tuning",
    "LU: Lookup Cache",
))
def generate(ids):
    """Generate one completion for ``ids`` and return it as a decoded string.

    Uses the editable model built by ``reset()`` when one exists. Before any
    editing algorithm has been loaded (``st.session_state.editable_model`` is
    still ``None``, as set during backend initialization), it falls back to the
    unedited base model instead of failing with ``AttributeError`` on ``None``.

    Args:
        ids: Input token ids from ``st.session_state.tokenizer(...,
            return_tensors="pt")``, already moved to
            ``st.session_state.device`` by the caller.

    Returns:
        The first decoded output sequence, with special tokens stripped.
    """
    model = st.session_state.editable_model
    if model is None:
        # No editing algorithm loaded yet: sample the unedited base model.
        model = st.session_state.model
    output_ids = model.generate(
        input_ids=ids,
        max_new_tokens=20,
        min_length=1,
        num_return_sequences=1,
        num_beams=3,
    )
    return st.session_state.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
def reset():
    """Clear the edit table and generation history, then load the selected algorithm.

    Registered as the ``on_change`` callback of the algorithm selector and the
    ``on_click`` callback of the "Clear edits" button. The selected label
    (e.g. "MEND: ...") is reduced to its abbreviation, which is used to import
    ``algs.<abbrv>``, fetch class ``<ABBRV>`` from it and read
    ``config.<abbrv>_config``. The wrapped editable model is stored in
    ``st.session_state.editable_model``.
    """
    # Function-scope import: ``copy`` is not imported at file level, and the
    # model-copy factory below needs it.
    import copy

    # Empty both tables in place. The history must be dropped by its *own*
    # index; using ``edits.index`` (already emptied above) removed nothing.
    st.session_state.edits.drop(st.session_state.edits.index, inplace=True)
    st.session_state.model_outputs.drop(st.session_state.model_outputs.index, inplace=True)

    selected_alg = st.session_state.alg_selector
    with st.spinner('Loading model...'):
        # "MEND: Model editor ..." -> "MEND"
        alg_abbrv = selected_alg[:selected_alg.index(":")]
        alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
        alg_class = getattr(alg_module, alg_abbrv.upper())
        st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
        st.session_state.editable_model = alg_class(
            st.session_state.model,
            st.session_state.config,
            # Factory handing the algorithm a fresh deep copy of the base model.
            lambda: copy.deepcopy(st.session_state.model),
        ).eval()
def apply_edit():
    """Record the (edit input, edit target) pair currently typed in the edit boxes.

    Appends one row to ``st.session_state.edits`` using the module-level
    ``edit_input`` / ``edit_label`` values bound by the edit text boxes.
    The model itself is not modified yet.
    """
    new_row = [str(edit_input), str(edit_label)]
    edits_table = st.session_state.edits
    edits_table.loc[len(edits_table)] = new_row
    # TODO: actually apply the edit to the model.
def sample_model():
    """Run the model on the sampling text box and append the result to the history.

    Reads the module-level ``test_input`` bound by the sampling text box,
    generates a completion via ``generate()``, and appends
    ``[input, output, number of edits, algorithm abbreviation]`` to
    ``st.session_state.model_outputs``.
    """
    prompt = str(test_input)
    with st.spinner('Generating completion...'):
        tokenized = st.session_state.tokenizer(prompt, return_tensors="pt")
        token_ids = tokenized["input_ids"].to(st.session_state.device)
        completion = generate(token_ids)

    edit_count = len(st.session_state.edits)
    selected = st.session_state.alg_selector
    history = st.session_state.model_outputs
    history.loc[len(history)] = [
        prompt,
        completion,
        edit_count,
        selected[:selected.index(":")],
    ]
################################
#### Backend initialization ####
################################
# Runs only while st.session_state has no "init" flag: creates the edit and
# generation-history tables, picks the device, and loads the base tokenizer
# and model.
if "init" not in st.session_state:
    st.session_state.edits = pd.DataFrame([], columns=["Edit input", "Edit label"])
    st.session_state.model_outputs = pd.DataFrame([], columns=["Input", "Output", "N edits", "Alg"])
    st.session_state.config = None
    st.session_state.device = "cuda" if use_cuda() else "cpu"
    with st.spinner('Loading model...'):
        st.session_state.tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
        st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").to(st.session_state.device).eval()
    # No editing algorithm is loaded until reset() runs.
    st.session_state.editable_model = None
    # Set the flag last: if loading above raises, the next rerun retries the
    # whole setup instead of skipping it with session_state half-populated.
    st.session_state.init = True
########################
#### Interface code ####
########################
# Page header and introductory text.
st.title("Language Model Editing")
st.markdown("**Note: this HF space is currently under development and doesn't actually work yet!**")
st.markdown("The goal of this demo is to give you a sense of the *abilities* and *limitations* of existing methods for **editing** pre-trained language models. **Model editing** algorithms use a single input-output pair to update a pre-trained model's behavior for that input (and ideally, related inputs).")
st.markdown("This demo uses a [T5-large](https://huggingface.co/google/t5-large-ssm-nq) model fine-tuned on [Natural Questions](https://arxiv.org/pdf/2002.08910.pdf) as the base pre-trained model.")
st.write("You can choose from a variety of algorithms for model editing in the dropdown below. At the bottom of the page, you can query the model for whatever input you want before/after editing.")
st.markdown("***")

# Algorithm picker with a "Clear edits" button beside it; both call reset().
selector_col, clear_col = st.columns([5, 1])
with selector_col:
    alg_selector = st.selectbox("Editing algorithm:", EDIT_ALGS, key="alg_selector", on_change=reset)
with clear_col:
    # Blank-looking filler text, presumably to push the button down so it
    # lines up with the selectbox.
    st.text("ㅤ")
    st.button("Clear edits", on_click=reset)

st.write("Edits applied so far:")
st.table(st.session_state.edits)
# Edit entry row: question text, desired answer, and the "Apply edit" button.
# edit_input / edit_label are read as module-level names by apply_edit().
input_col, target_col, button_col = st.columns([3, 2, 1])
with input_col:
    edit_input = st.text_input("Edit input:", placeholder="e.g., 'What is the tallest mountain on Earth?'")
with target_col:
    edit_label = st.text_input("Edit target:", placeholder="e.g., 'Denali'")
with button_col:
    # Blank-looking filler text, presumably for vertical alignment with the inputs.
    st.text("ㅤ")
    edit_button = st.button("Apply edit", on_click=apply_edit)
st.markdown("***")
# Label the sampling box according to whether any edits have been recorded.
# (Plain strings: the previous f-string had no placeholders.)
title = (
    "Input to sample from *unedited* model:"
    if st.session_state.edits.empty
    else "Input to sample from *edited* model:"
)
# Sampling row: query text box plus "Generate" button.
# test_input is read as a module-level name by sample_model().
col1, col2 = st.columns([5, 1])
with col1:
    test_input = st.text_input(title, placeholder="e.g., 'What is the earth's tallest mountain?'")
with col2:
    # Blank-looking filler text, presumably for vertical alignment with the input.
    st.text("ㅤ")
    generate_button = st.button("Generate", on_click=sample_model)

st.write("Model generation history:")
st.table(st.session_state.model_outputs)