Spaces:
Runtime error
Runtime error
frankaging
committed on
Commit
·
bddba98
1
Parent(s):
e3ab52c
o1 impl
Browse files
app.py
CHANGED
|
@@ -13,7 +13,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
| 13 |
login(token=HF_TOKEN)
|
| 14 |
|
| 15 |
MAX_MAX_NEW_TOKENS = 2048
|
| 16 |
-
DEFAULT_MAX_NEW_TOKENS =
|
| 17 |
MAX_INPUT_TOKEN_LENGTH = 4096
|
| 18 |
|
| 19 |
def load_jsonl(jsonl_path):
|
|
@@ -29,7 +29,8 @@ class Steer(pv.SourcelessIntervention):
|
|
| 29 |
def __init__(self, **kwargs):
|
| 30 |
super().__init__(**kwargs, keep_last_dim=True)
|
| 31 |
self.proj = torch.nn.Linear(
|
| 32 |
-
|
|
|
|
| 33 |
def forward(self, base, source=None, subspaces=None):
|
| 34 |
steering_vec = torch.tensor(subspaces["mag"]) * \
|
| 35 |
self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
|
|
@@ -82,9 +83,9 @@ def generate(
|
|
| 82 |
|
| 83 |
# build list of messages
|
| 84 |
messages = []
|
| 85 |
-
for user_msg, model_msg in recent_history:
|
| 86 |
-
|
| 87 |
-
|
| 88 |
messages.append({"role": "user", "content": message})
|
| 89 |
|
| 90 |
input_ids = torch.tensor([tokenizer.apply_chat_template(
|
|
@@ -101,7 +102,12 @@ def generate(
|
|
| 101 |
"unit_locations": None,
|
| 102 |
"max_new_tokens": max_new_tokens,
|
| 103 |
"intervene_on_prompt": True,
|
| 104 |
-
"subspaces": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
"streamer": streamer,
|
| 106 |
"do_sample": True
|
| 107 |
}
|
|
@@ -121,8 +127,14 @@ def filter_concepts(search_text: str):
|
|
| 121 |
return filtered[:500]
|
| 122 |
|
| 123 |
def add_concept_to_list(selected_concept, user_slider_val, current_list):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
if not selected_concept:
|
| 125 |
-
return current_list,
|
|
|
|
| 126 |
idx = concept_id_map[selected_concept]
|
| 127 |
internal_mag = user_slider_val * 50
|
| 128 |
new_entry = {
|
|
@@ -132,24 +144,18 @@ def add_concept_to_list(selected_concept, user_slider_val, current_list):
|
|
| 132 |
"internal_mag": internal_mag,
|
| 133 |
}
|
| 134 |
updated_list = current_list + [new_entry]
|
| 135 |
-
return (
|
| 136 |
-
updated_list,
|
| 137 |
-
_build_table_data(updated_list),
|
| 138 |
-
gr.update(choices=_build_remove_choices(updated_list))
|
| 139 |
-
)
|
| 140 |
|
| 141 |
def remove_concept_from_list(selected_text, current_list):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
if not selected_text:
|
| 143 |
-
return current_list,
|
| 144 |
updated_list = [x for x in current_list if x["text"] != selected_text]
|
| 145 |
-
return (
|
| 146 |
-
updated_list,
|
| 147 |
-
_build_table_data(updated_list),
|
| 148 |
-
gr.update(choices=_build_remove_choices(updated_list))
|
| 149 |
-
)
|
| 150 |
-
|
| 151 |
-
def _build_table_data(subspaces):
|
| 152 |
-
return [[x["text"], x["display_mag"]] for x in subspaces]
|
| 153 |
|
| 154 |
def _build_remove_choices(subspaces):
|
| 155 |
return [x["text"] for x in subspaces]
|
|
@@ -211,12 +217,23 @@ with gr.Blocks(css="style.css") as demo:
|
|
| 211 |
remove_button = gr.Button("Remove", variant="secondary")
|
| 212 |
|
| 213 |
# Wire up events
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
add_button.click(
|
| 216 |
add_concept_to_list,
|
| 217 |
[concept_dropdown, concept_magnitude, selected_subspaces],
|
| 218 |
[selected_subspaces, remove_dropdown]
|
| 219 |
)
|
|
|
|
|
|
|
|
|
|
| 220 |
remove_button.click(
|
| 221 |
remove_concept_from_list,
|
| 222 |
[remove_dropdown, selected_subspaces],
|
|
|
|
| 13 |
login(token=HF_TOKEN)
|
| 14 |
|
| 15 |
MAX_MAX_NEW_TOKENS = 2048
|
| 16 |
+
DEFAULT_MAX_NEW_TOKENS = 128 # smaller default to save memory
|
| 17 |
MAX_INPUT_TOKEN_LENGTH = 4096
|
| 18 |
|
| 19 |
def load_jsonl(jsonl_path):
|
|
|
|
| 29 |
def __init__(self, **kwargs):
|
| 30 |
super().__init__(**kwargs, keep_last_dim=True)
|
| 31 |
self.proj = torch.nn.Linear(
|
| 32 |
+
self.embed_dim, kwargs["latent_dim"], bias=False
|
| 33 |
+
)
|
| 34 |
def forward(self, base, source=None, subspaces=None):
|
| 35 |
steering_vec = torch.tensor(subspaces["mag"]) * \
|
| 36 |
self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
|
|
|
|
| 83 |
|
| 84 |
# build list of messages
|
| 85 |
messages = []
|
| 86 |
+
# for user_msg, model_msg in recent_history:
|
| 87 |
+
# messages.append({"role": "user", "content": user_msg})
|
| 88 |
+
# messages.append({"role": "model", "content": model_msg})
|
| 89 |
messages.append({"role": "user", "content": message})
|
| 90 |
|
| 91 |
input_ids = torch.tensor([tokenizer.apply_chat_template(
|
|
|
|
| 102 |
"unit_locations": None,
|
| 103 |
"max_new_tokens": max_new_tokens,
|
| 104 |
"intervene_on_prompt": True,
|
| 105 |
+
"subspaces": [
|
| 106 |
+
{
|
| 107 |
+
"idx": int(subspaces_list[0]["idx"]),
|
| 108 |
+
"mag": int(subspaces_list[0]["internal_mag"])
|
| 109 |
+
}
|
| 110 |
+
] if subspaces_list else [],
|
| 111 |
"streamer": streamer,
|
| 112 |
"do_sample": True
|
| 113 |
}
|
|
|
|
| 127 |
return filtered[:500]
|
| 128 |
|
| 129 |
def add_concept_to_list(selected_concept, user_slider_val, current_list):
|
| 130 |
+
"""
|
| 131 |
+
Return exactly 2 values:
|
| 132 |
+
1) The updated list of concepts (list of dicts).
|
| 133 |
+
2) A Gradio update for the removal dropdown’s choices.
|
| 134 |
+
"""
|
| 135 |
if not selected_concept:
|
| 136 |
+
return current_list, gr.update(choices=_build_remove_choices(current_list))
|
| 137 |
+
|
| 138 |
idx = concept_id_map[selected_concept]
|
| 139 |
internal_mag = user_slider_val * 50
|
| 140 |
new_entry = {
|
|
|
|
| 144 |
"internal_mag": internal_mag,
|
| 145 |
}
|
| 146 |
updated_list = current_list + [new_entry]
|
| 147 |
+
return updated_list, gr.update(choices=_build_remove_choices(updated_list))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
def remove_concept_from_list(selected_text, current_list):
|
| 150 |
+
"""
|
| 151 |
+
Return exactly 2 values:
|
| 152 |
+
1) The updated list of concepts (list of dicts).
|
| 153 |
+
2) A Gradio update for the removal dropdown’s choices.
|
| 154 |
+
"""
|
| 155 |
if not selected_text:
|
| 156 |
+
return current_list, gr.update(choices=_build_remove_choices(current_list))
|
| 157 |
updated_list = [x for x in current_list if x["text"] != selected_text]
|
| 158 |
+
return updated_list, gr.update(choices=_build_remove_choices(updated_list))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
def _build_remove_choices(subspaces):
|
| 161 |
return [x["text"] for x in subspaces]
|
|
|
|
| 217 |
remove_button = gr.Button("Remove", variant="secondary")
|
| 218 |
|
| 219 |
# Wire up events
|
| 220 |
+
# When the search box changes, update the concept dropdown choices:
|
| 221 |
+
search_box.change(
|
| 222 |
+
update_dropdown_choices,
|
| 223 |
+
[search_box],
|
| 224 |
+
[concept_dropdown]
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# When "Add Concept" is clicked, add the concept + magnitude to the list,
|
| 228 |
+
# and update the "Remove" dropdown choices.
|
| 229 |
add_button.click(
|
| 230 |
add_concept_to_list,
|
| 231 |
[concept_dropdown, concept_magnitude, selected_subspaces],
|
| 232 |
[selected_subspaces, remove_dropdown]
|
| 233 |
)
|
| 234 |
+
|
| 235 |
+
# When "Remove" is clicked, remove the selected concept from the list,
|
| 236 |
+
# and update the "Remove" dropdown choices.
|
| 237 |
remove_button.click(
|
| 238 |
remove_concept_from_list,
|
| 239 |
[remove_dropdown, selected_subspaces],
|