Spaces:
Runtime error
Runtime error
Charles Lin
commited on
Commit
·
8335d0c
1
Parent(s):
a9853a7
All algs except KE working.
Browse files
algs/lu.py
CHANGED
@@ -15,56 +15,45 @@ class LU(EditableModel):
|
|
15 |
def __init__(self, model, config, model_constructor, memory=None):
|
16 |
super().__init__(model, config, model_constructor)
|
17 |
|
|
|
|
|
18 |
self.memory = memory
|
19 |
|
20 |
-
def
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
if self.memory is not None:
|
29 |
-
for i, encoder_state in enumerate(encoder_states):
|
30 |
-
if "gpt2" in self.config.model.name.lower():
|
31 |
-
# NOTE: broken
|
32 |
-
memory_prefixes, memory_labels = self.memory
|
33 |
-
prefix_means = encoder_state.cumsum(0).detach() / torch.arange(1, encoder_state.shape[0] + 1, device=encoder_state.device).view(-1, 1)
|
34 |
-
dist_mat = (prefix_means.unsqueeze(1) - memory_prefixes.unsqueeze(0)).norm(2, dim=-1)
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
closest_v = memory_labels[closest_idx]
|
47 |
-
|
48 |
-
if closest_dist < self.config.lu.threshold:
|
49 |
-
output[i] = torch.zeros((1, kwargs['labels'].shape[1], output.shape[2]), device=output.device)
|
50 |
-
for j, idx in enumerate(closest_v):
|
51 |
-
if j >= output.shape[1]:
|
52 |
-
break
|
53 |
-
output[i, j, idx] = self.config.lu.onehot_logit
|
54 |
-
if "t5" not in self.config.model.name.lower():
|
55 |
-
# T5 does not shift targets in the loss
|
56 |
-
output[i] = output[i].roll(-1, -2)
|
57 |
-
else:
|
58 |
-
avg_encoder_state = encoder_state.detach().mean(0)
|
59 |
-
memory_keys, memory_labels = self.memory
|
60 |
-
dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
|
61 |
-
closest_dist = dists.min()
|
62 |
-
closest_idx = dists.argmin()
|
63 |
-
closest_v = memory_labels[closest_idx]
|
64 |
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
return output
|
69 |
|
70 |
def edit(self, batch, condition=None, detach_history=False):
|
@@ -77,14 +66,9 @@ class LU(EditableModel):
|
|
77 |
memory_keys = []
|
78 |
memory_labels = []
|
79 |
for encoder_state, label in zip(encoder_states, batch["labels"]):
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
memory = (avg_encoder_states, label[-10:])
|
84 |
-
else:
|
85 |
-
avg_encoder_state = encoder_state.detach().mean(0)
|
86 |
-
memory_keys.append(avg_encoder_state)
|
87 |
-
memory_labels.append(label)
|
88 |
|
89 |
memory = (torch.stack(memory_keys), torch.stack(memory_labels))
|
90 |
return LU(self.model.eval(), self.config, self.model_constructor, memory), {}
|
|
|
15 |
def __init__(self, model, config, model_constructor, memory=None):
|
16 |
super().__init__(model, config, model_constructor)
|
17 |
|
18 |
+
if "t5" not in self.config.model.name.lower():
|
19 |
+
raise NotImplementedError
|
20 |
self.memory = memory
|
21 |
|
22 |
+
def lookup_replace(self, output, encoder_states):
|
23 |
+
for i, encoder_state in enumerate(encoder_states):
|
24 |
+
avg_encoder_state = encoder_state.detach().mean(0)
|
25 |
+
memory_keys, memory_labels = self.memory
|
26 |
+
dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
|
27 |
+
closest_dist = dists.min()
|
28 |
+
closest_idx = dists.argmin()
|
29 |
+
closest_v = memory_labels[closest_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
if closest_dist < self.config.lu.threshold:
|
32 |
+
output[i] = torch.zeros((1, output.shape[1], output.shape[2]), device=output.device)
|
33 |
+
for j, idx in enumerate(closest_v):
|
34 |
+
if j >= output.shape[1]:
|
35 |
+
break
|
36 |
+
output[i, j, idx] = self.config.lu.onehot_logit
|
37 |
+
if "t5" not in self.config.model.name.lower():
|
38 |
+
# T5 does not shift targets in the loss
|
39 |
+
output[i] = output[i].roll(-1, -2)
|
40 |
+
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
+
def generate(self, *inputs, **kwargs):
|
43 |
+
model_output = self.model.generate(*inputs, **kwargs, output_hidden_states=True,
|
44 |
+
output_scores=True, return_dict_in_generate=True)
|
45 |
+
encoder_states = _last_encoder_state(model_output)
|
46 |
+
output = _logits(model_output)
|
47 |
+
if self.memory is not None:
|
48 |
+
output = self.lookup_replace(output, encoder_states)
|
49 |
+
return output.argmax(-1)
|
50 |
|
51 |
+
def forward(self, *inputs, **kwargs):
|
52 |
+
model_output = self.model(*inputs, **kwargs, output_hidden_states=True)
|
53 |
+
encoder_states = _last_encoder_state(model_output)
|
54 |
+
output = _logits(model_output)
|
55 |
+
if self.memory is not None:
|
56 |
+
output = self.lookup_replace(output, encoder_states)
|
57 |
return output
|
58 |
|
59 |
def edit(self, batch, condition=None, detach_history=False):
|
|
|
66 |
memory_keys = []
|
67 |
memory_labels = []
|
68 |
for encoder_state, label in zip(encoder_states, batch["labels"]):
|
69 |
+
avg_encoder_state = encoder_state.detach().mean(0)
|
70 |
+
memory_keys.append(avg_encoder_state)
|
71 |
+
memory_labels.append(label)
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
memory = (torch.stack(memory_keys), torch.stack(memory_labels))
|
74 |
return LU(self.model.eval(), self.config, self.model_constructor, memory), {}
|
app.py
CHANGED
@@ -8,6 +8,7 @@ from torch.cuda import is_available as use_cuda
|
|
8 |
import algs
|
9 |
import config
|
10 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
11 |
|
12 |
|
13 |
EDIT_ALGS = [
|
@@ -19,6 +20,26 @@ EDIT_ALGS = [
|
|
19 |
"LU: Lookup Cache",
|
20 |
]
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
def generate(ids):
|
23 |
output_ids = st.session_state.editable_model.generate(input_ids=ids, max_new_tokens=20, min_length=1,
|
24 |
num_return_sequences=1, num_beams=3)
|
@@ -30,15 +51,7 @@ def reset():
|
|
30 |
|
31 |
selected_alg = st.session_state.alg_selector
|
32 |
alg_abbrv = selected_alg[:selected_alg.index(":")]
|
33 |
-
|
34 |
-
alg_class = getattr(alg_module, alg_abbrv.upper())
|
35 |
-
st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
|
36 |
-
with st.spinner('Loading model...'):
|
37 |
-
st.session_state.editable_model = alg_class(
|
38 |
-
st.session_state.model,
|
39 |
-
st.session_state.config,
|
40 |
-
lambda: copy.deepcopy(st.session_state.model),
|
41 |
-
).eval()
|
42 |
|
43 |
def apply_edit():
|
44 |
st.session_state.edits.loc[len(st.session_state.edits)] = [str(edit_input), str(edit_label)]
|
@@ -67,12 +80,13 @@ if "init" not in st.session_state:
|
|
67 |
st.session_state.edits = pd.DataFrame([], columns=["Edit input", "Edit label"])
|
68 |
st.session_state.model_outputs = pd.DataFrame([], columns=["Input", "Output", "N edits", "Alg"])
|
69 |
st.session_state.init = True
|
70 |
-
st.session_state.
|
71 |
-
st.session_state.device = "cuda" if use_cuda() else "cpu"
|
72 |
with st.spinner('Loading model...'):
|
73 |
st.session_state.tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
|
74 |
st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").to(st.session_state.device).eval()
|
75 |
-
|
|
|
|
|
76 |
|
77 |
########################
|
78 |
#### Interface code ####
|
|
|
8 |
import algs
|
9 |
import config
|
10 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
11 |
+
import utils
|
12 |
|
13 |
|
14 |
EDIT_ALGS = [
|
|
|
20 |
"LU: Lookup Cache",
|
21 |
]
|
22 |
|
23 |
+
def get_alg_class(alg_abbrv):
|
24 |
+
alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
|
25 |
+
alg_class = getattr(alg_module, alg_abbrv.upper())
|
26 |
+
return alg_class
|
27 |
+
|
28 |
+
def load_editable_model(alg_abbrv):
|
29 |
+
alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
|
30 |
+
alg_class = getattr(alg_module, alg_abbrv.upper())
|
31 |
+
st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
|
32 |
+
with st.spinner('Loading model...'):
|
33 |
+
st.session_state.editable_model = alg_class(
|
34 |
+
st.session_state.model,
|
35 |
+
st.session_state.config,
|
36 |
+
lambda: copy.deepcopy(st.session_state.model),
|
37 |
+
).eval()
|
38 |
+
if "archive" in st.session_state.config:
|
39 |
+
archive, st.session_state.config.archive = utils.load_archive(str(st.session_state.config.archive))
|
40 |
+
print(f"Loading archive from {st.session_state.config.archive}")
|
41 |
+
st.session_state.editable_model.load_state_dict(archive["model"])
|
42 |
+
|
43 |
def generate(ids):
|
44 |
output_ids = st.session_state.editable_model.generate(input_ids=ids, max_new_tokens=20, min_length=1,
|
45 |
num_return_sequences=1, num_beams=3)
|
|
|
51 |
|
52 |
selected_alg = st.session_state.alg_selector
|
53 |
alg_abbrv = selected_alg[:selected_alg.index(":")]
|
54 |
+
load_editable_model(alg_abbrv)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
def apply_edit():
|
57 |
st.session_state.edits.loc[len(st.session_state.edits)] = [str(edit_input), str(edit_label)]
|
|
|
80 |
st.session_state.edits = pd.DataFrame([], columns=["Edit input", "Edit label"])
|
81 |
st.session_state.model_outputs = pd.DataFrame([], columns=["Input", "Output", "N edits", "Alg"])
|
82 |
st.session_state.init = True
|
83 |
+
st.session_state.device = "cpu" # "cuda" if use_cuda() else "cpu"
|
|
|
84 |
with st.spinner('Loading model...'):
|
85 |
st.session_state.tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm-nq")
|
86 |
st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-large-ssm-nq").to(st.session_state.device).eval()
|
87 |
+
# There is a "Loading model..." spinner in load_editable_model
|
88 |
+
alg_abbrv = "MEND" # Default initial alg of dropdown selector
|
89 |
+
load_editable_model(alg_abbrv)
|
90 |
|
91 |
########################
|
92 |
#### Interface code ####
|
config.py
CHANGED
@@ -21,7 +21,7 @@ model_config = {
|
|
21 |
}
|
22 |
|
23 |
ft_config = OmegaConf.create({
|
24 |
-
"device": "
|
25 |
"edit_lr": 5e-6,
|
26 |
"train_base": False,
|
27 |
"grad_clip": 100,
|
@@ -43,7 +43,7 @@ ft_config = OmegaConf.create({
|
|
43 |
})
|
44 |
|
45 |
lu_config = OmegaConf.create({
|
46 |
-
"device": "
|
47 |
"lu": {
|
48 |
"threshold": 2.75,
|
49 |
"onehot_logit": 1,
|
@@ -52,14 +52,14 @@ lu_config = OmegaConf.create({
|
|
52 |
})
|
53 |
|
54 |
ke_config = OmegaConf.create({
|
55 |
-
"device": "
|
56 |
"train_base": False,
|
57 |
"lr": 1e-5,
|
58 |
"model": model_config,
|
59 |
})
|
60 |
|
61 |
enn_config = OmegaConf.create({
|
62 |
-
"device": "
|
63 |
"lr": 1e-5,
|
64 |
"edit_lr": 1e-2,
|
65 |
"lr_lr": 1e-3,
|
@@ -72,10 +72,11 @@ enn_config = OmegaConf.create({
|
|
72 |
"n_edit_steps": 1,
|
73 |
},
|
74 |
"model": model_config,
|
|
|
75 |
})
|
76 |
|
77 |
mend_config = OmegaConf.create({
|
78 |
-
"device": "
|
79 |
"lr": 1e-6,
|
80 |
"edit_lr": 1e-4,
|
81 |
"lr_lr": 1e-4,
|
@@ -99,10 +100,11 @@ mend_config = OmegaConf.create({
|
|
99 |
"descent": False,
|
100 |
},
|
101 |
"model": model_config,
|
|
|
102 |
})
|
103 |
|
104 |
serac_config = OmegaConf.create({
|
105 |
-
"device": "cuda" if use_cuda() else "cpu",
|
106 |
"lr": 1e-5,
|
107 |
"edit_lr": 1e-2,
|
108 |
"lr_lr": 0,
|
@@ -128,4 +130,5 @@ serac_config = OmegaConf.create({
|
|
128 |
"cache_embeds": True,
|
129 |
},
|
130 |
"model": model_config,
|
|
|
131 |
})
|
|
|
21 |
}
|
22 |
|
23 |
ft_config = OmegaConf.create({
|
24 |
+
"device": "cpu",
|
25 |
"edit_lr": 5e-6,
|
26 |
"train_base": False,
|
27 |
"grad_clip": 100,
|
|
|
43 |
})
|
44 |
|
45 |
lu_config = OmegaConf.create({
|
46 |
+
"device": "cpu",
|
47 |
"lu": {
|
48 |
"threshold": 2.75,
|
49 |
"onehot_logit": 1,
|
|
|
52 |
})
|
53 |
|
54 |
ke_config = OmegaConf.create({
|
55 |
+
"device": "cpu",
|
56 |
"train_base": False,
|
57 |
"lr": 1e-5,
|
58 |
"model": model_config,
|
59 |
})
|
60 |
|
61 |
enn_config = OmegaConf.create({
|
62 |
+
"device": "cpu",
|
63 |
"lr": 1e-5,
|
64 |
"edit_lr": 1e-2,
|
65 |
"lr_lr": 1e-3,
|
|
|
72 |
"n_edit_steps": 1,
|
73 |
},
|
74 |
"model": model_config,
|
75 |
+
"archive": 8684705655, # "/iris/u/clin/code/efk/outputs/2022-02-09_05-48-20_8684705655/models/t5-large-ssm-nq.2022-02-09_05-48-20_8684705655",
|
76 |
})
|
77 |
|
78 |
mend_config = OmegaConf.create({
|
79 |
+
"device": "cpu",
|
80 |
"lr": 1e-6,
|
81 |
"edit_lr": 1e-4,
|
82 |
"lr_lr": 1e-4,
|
|
|
100 |
"descent": False,
|
101 |
},
|
102 |
"model": model_config,
|
103 |
+
"archive": 5940349945, # "/iris/u/clin/code/efk/outputs/2022-02-09_11-47-28_5940349945/models/t5-large-ssm-nq.2022-02-09_11-47-28_5940349945",
|
104 |
})
|
105 |
|
106 |
serac_config = OmegaConf.create({
|
107 |
+
"device": "cpu", # "device": "cuda" if use_cuda() else "cpu",
|
108 |
"lr": 1e-5,
|
109 |
"edit_lr": 1e-2,
|
110 |
"lr_lr": 0,
|
|
|
130 |
"cache_embeds": True,
|
131 |
},
|
132 |
"model": model_config,
|
133 |
+
"archive": 4719776130, # "/iris/u/clin/code/efk/outputs/2022-02-09_14-05-56_4719776130/models/t5-large-ssm-nq.2022-02-09_14-05-56_4719776130",
|
134 |
})
|
utils.py
CHANGED
@@ -156,12 +156,18 @@ def safe_backward(loss, parameters, accumulate=1, allow_unused=False, backward=F
|
|
156 |
|
157 |
|
158 |
def _logits(x):
|
159 |
-
|
|
|
|
|
|
|
|
|
160 |
|
161 |
|
162 |
def _last_encoder_state(x):
|
163 |
if hasattr(x, "encoder_last_hidden_state"):
|
164 |
return x.encoder_last_hidden_state
|
|
|
|
|
165 |
else:
|
166 |
return x.hidden_states[-1]
|
167 |
|
|
|
156 |
|
157 |
|
158 |
def _logits(x):
|
159 |
+
if hasattr(x, "logits"):
|
160 |
+
return x.logits
|
161 |
+
elif hasattr(x, "scores"):
|
162 |
+
return torch.cat(x.scores).unsqueeze(0)
|
163 |
+
return x
|
164 |
|
165 |
|
166 |
def _last_encoder_state(x):
|
167 |
if hasattr(x, "encoder_last_hidden_state"):
|
168 |
return x.encoder_last_hidden_state
|
169 |
+
elif hasattr(x, "encoder_hidden_states"):
|
170 |
+
return x.encoder_hidden_states[-1]
|
171 |
else:
|
172 |
return x.hidden_states[-1]
|
173 |
|