Spaces:
Runtime error
Runtime error
File size: 2,915 Bytes
9f3c7b7 65eef23 93798d6 89e27db 93798d6 89e27db 93798d6 89e27db 93798d6 b3740e7 93798d6 b3740e7 93798d6 b3740e7 65eef23 9f3c7b7 b3740e7 89e27db 93798d6 b3740e7 93798d6 b3740e7 93798d6 db7fef9 93798d6 b3740e7 93798d6 db7fef9 93798d6 db7fef9 93798d6 9f3c7b7 93798d6 9f3c7b7 b3740e7 db7fef9 93798d6 db7fef9 93798d6 42f5f04 93798d6 b3740e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from turtle import onclick
import streamlit as st
import pandas as pd
import time
# Editing algorithms offered in the "Editing algorithm:" selectbox.
# Order matters: reset() maps the selected label back to its position with
# EDIT_ALGS.index(...).
EDIT_ALGS = [
    "MEND: Model editor networks using gradient decomposition",
    "SERAC: Semi-parametric editing with a retrieval-augmented counterfactual model",
    "ENN: Editable neural networks",
    "KE: KnowledgeEditor",
    "Fine-tuning",
    "Lookup Cache",
]
# Page header: title, a short usage blurb, and a horizontal rule.
st.title("Language Model Editing")
st.write(
    "Choose an editing algorithm, apply some edits, and sample from the model to see how its behavior changes. "
    'You can sample the model at any time to see the "before and after" of the edits you apply.'
)
st.markdown("***")
# Persistence pattern adapted from
# https://discuss.streamlit.io/t/simple-example-of-persistence-and-waiting-for-input/2111
@st.cache(allow_output_mutation=True)
def Edits():
    """Cached placeholder that always returns None.

    NOTE(review): not referenced in the visible source; state is kept in
    st.session_state instead, so this looks like a leftover of the older
    persistence pattern. Confirm before removing.
    """
    return None
@st.cache(allow_output_mutation=True)
def ModelOutput():
    """Cached placeholder that always returns None.

    NOTE(review): not referenced in the visible source; generation history is
    kept in st.session_state.model_outputs instead. Confirm before removing.
    """
    return None
# One-time setup per session: create the empty tables holding the applied
# edits and the model generation history. The "init" flag makes later script
# reruns reuse the existing tables instead of recreating them.
if "init" not in st.session_state:
    st.session_state["edits"] = pd.DataFrame(
        [], columns=["Edit input", "Edit label"]
    )
    st.session_state["model_outputs"] = pd.DataFrame(
        [], columns=["Input", "Generation", "Edits applied"]
    )
    st.session_state["init"] = True
def reset():
    """Clear all applied edits and the model generation history.

    Used as the on_change callback of the algorithm selectbox and the
    on_click callback of the "Clear edits" button.
    """
    edits = st.session_state.edits
    outputs = st.session_state.model_outputs

    edits.drop(edits.index, inplace=True)
    # Drop the history by its *own* index: the edits index is already empty
    # at this point, so using it here would leave the history untouched.
    outputs.drop(outputs.index, inplace=True)

    # Currently unused; kept for the pending model reset below.
    selected_alg = st.session_state.alg_selector
    selected_alg_idx = EDIT_ALGS.index(selected_alg)
    # TODO: reset the model for `selected_alg_idx` here (and maybe show a
    # progress spinner).
def apply_edit():
    """Append the current edit input / edit target pair to the edits table.

    Callback of the "Apply edit" button. Reads the module-level
    ``edit_input`` and ``edit_label`` values returned by the text inputs.
    """
    edits = st.session_state.edits
    new_row = [str(edit_input), str(edit_label)]
    # Use the current row count as the new row's label, i.e. append at the end.
    edits.loc[len(edits)] = new_row
def sample_model():
    """Record a generation for the current test input in the history table.

    Callback of the "Generate" button. Each row stores the prompt, the
    generated text, and how many edits had been applied at that moment.
    """
    history = st.session_state.model_outputs
    n_edits = len(st.session_state.edits)
    # NOTE(review): "blah blah blah" is a stand-in; no model is called yet.
    history.loc[len(history)] = [str(test_input), "blah blah blah", n_edits]
# Row 1: algorithm picker plus a "Clear edits" button, then the table of edits
# applied so far. Both the picker and the button trigger reset().
selector_col, clear_col = st.columns([5, 1])
alg_selector = selector_col.selectbox(
    "Editing algorithm:", EDIT_ALGS, key="alg_selector", on_change=reset
)
# Presumably an invisible spacer so the button lines up with the selectbox.
clear_col.text("ㅤ")
clear_col.button("Clear edits", on_click=reset)

st.write("Edits applied so far:")
st.table(st.session_state.edits)
st.markdown("***")
# Row 2: edit input, its target label, and the "Apply edit" button.
# edit_input / edit_label stay module-level because apply_edit() reads them.
input_col, label_col, apply_col = st.columns([3, 2, 1])
edit_input = input_col.text_input(
    "Edit input:", placeholder="e.g., 'What is the tallest mountain on Earth?'"
)
edit_label = label_col.text_input("Edit target:", placeholder="e.g., 'Denali'")
# Presumably an invisible spacer so the button lines up with the inputs.
apply_col.text("ㅤ")
edit_button = apply_col.button("Apply edit", on_click=apply_edit)
st.markdown("***")
# Label for the sampling prompt: tells the user whether any edits are active.
title = (
    "Input to sample from *unedited* model:"
    if len(st.session_state.edits) == 0
    else "Input to sample from *edited* model:"
)
# Row 3: prompt to sample from the (possibly edited) model plus "Generate".
# test_input stays module-level because sample_model() reads it.
prompt_col, generate_col = st.columns([5, 1])
test_input = prompt_col.text_input(
    title, placeholder="e.g., 'What is the earth's tallest mountain?'"
)
# Presumably an invisible spacer so the button lines up with the input.
generate_col.text("ㅤ")
generate_button = generate_col.button("Generate", on_click=sample_model)
# Full history of generations; sample_model() appends rows, so newest is last.
st.write("Model generation history:")
st.table(st.session_state.model_outputs)