frankaging committed
Commit 36edf66 · 1 Parent(s): 7962ddb
Files changed (2)
  1. app.py +100 -27
  2. style.css +0 -19
app.py CHANGED
@@ -2,12 +2,13 @@ import os, json, random
 import torch
 import gradio as gr
 import spaces
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from huggingface_hub import login, hf_hub_download
 import pyreft
 import pyvene as pv
 from threading import Thread
 from typing import Iterator
+import torch.nn.functional as F

 HF_TOKEN = os.environ.get("HF_TOKEN")
 login(token=HF_TOKEN)
@@ -16,6 +17,18 @@ MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 256 # smaller default to save memory
 MAX_INPUT_TOKEN_LENGTH = 4096

+css = """
+#alert-message textarea {
+    background-color: #e8f4ff;
+    border: 1px solid #cce5ff;
+    color: #084298;
+    font-size: 1.1em;
+    padding: 12px;
+    border-radius: 4px;
+    font-weight: 500;
+}
+"""
+
 def load_jsonl(jsonl_path):
     jsonl_data = []
     with open(jsonl_path, 'r') as f:
@@ -29,19 +42,44 @@ class Steer(pv.SourcelessIntervention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs, keep_last_dim=True)
         self.proj = torch.nn.Linear(
-            self.embed_dim, kwargs["latent_dim"], bias=False
-        )
+            self.embed_dim, kwargs["latent_dim"], bias=False)
+        self.subspace_generator = kwargs["subspace_generator"]
     def forward(self, base, source=None, subspaces=None):
-        if subspaces is None:
+        if subspaces == None:
             return base
-        steering_vec = []
-        avg_mag = sum(subspaces["mag"]) / len(subspaces["mag"])
-        for idx, mag in zip(subspaces["idx"], subspaces["mag"]):
-            steering_vec.append(self.proj.weight[idx].unsqueeze(dim=0))
-        steering_vec = torch.cat(steering_vec, dim=0).mean(dim=0)
-        steering_vec = avg_mag * steering_vec
+        if subspaces["subspace_gen_inputs"] is not None:
+            # we call our subspace generator to generate the subspace on-the-fly.
+            raw_steering_vec = self.subspace_generator(
+                subspaces["subspace_gen_inputs"]["input_ids"],
+                subspaces["subspace_gen_inputs"]["attention_mask"],
+            )[0]
+            steering_vec = torch.tensor(subspaces["mag"]) * \
+                raw_steering_vec.unsqueeze(dim=0)
+            return base + steering_vec
+        else:
+            steering_vec = torch.tensor(subspaces["mag"]) * \
+                self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
         return base + steering_vec

+class RegressionWrapper(torch.nn.Module):
+    def __init__(self, base_model, hidden_size, output_dim):
+        super().__init__()
+        self.base_model = base_model
+        self.regression_head = torch.nn.Linear(hidden_size, output_dim)
+
+    def forward(self, input_ids, attention_mask):
+        outputs = self.base_model.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+            return_dict=True
+        )
+        last_hiddens = outputs.hidden_states[-1]
+        last_token_representations = last_hiddens[:, -1]
+        preds = self.regression_head(last_token_representations)
+        preds = F.normalize(preds, p=2, dim=-1)
+        return preds
+
 # Check GPU
 if not torch.cuda.is_available():
     print("Warning: Running on CPU, may be slow.")
@@ -73,7 +111,23 @@ if torch.cuda.is_available():
         concept_id_map[item["concept"]] = concept_reindex
         concept_reindex += 1

-    steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
+    # load subspace generator.
+    base_tokenizer = AutoTokenizer.from_pretrained(
+        f"google/gemma-2-2b", model_max_length=512)
+    config = AutoConfig.from_pretrained("google/gemma-2-2b")
+    base_model = AutoModelForCausalLM.from_config(config)
+
+    subspace_generator_weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res-generator", filename="l20/weight.pt")
+    hidden_size = base_model.config.hidden_size
+    subspace_generator = RegressionWrapper(
+        base_model, hidden_size, hidden_size).bfloat16().to("cuda")
+    subspace_generator.load_state_dict(torch.load(subspace_generator_weight_path))
+    print(f"Loading model from saved file {subspace_generator_weight_path}")
+    _ = subspace_generator.eval()
+
+    steer = Steer(
+        embed_dim=params.shape[0], latent_dim=params.shape[1],
+        subspace_generator=subspace_generator)
     steer.proj.weight.data = params.float()

     pv_model = pv.IntervenableModel({
@@ -117,8 +171,10 @@ def generate(
         "intervene_on_prompt": True,
         "subspaces": [
             {
-                "idx": [int(sl["idx"]) for sl in subspaces_list],
-                "mag": [int(sl["internal_mag"]) for sl in subspaces_list]
+                "idx": int(subspaces_list[0]["idx"]),
+                "mag": int(subspaces_list[0]["internal_mag"]),
+                "subspace_gen_inputs": base_tokenizer(subspaces_list[0]["subspace_gen_text"], return_tensors="pt").to("cuda") \
+                    if subspaces_list[0]["subspace_gen_text"] is not None else None
             }
         ] if subspaces_list else None,
         "streamer": streamer,
@@ -133,9 +189,6 @@
         partial_text.append(token_str)
         yield "".join(partial_text)

-def _build_remove_choices(subspaces):
-    return [f"(+{x['display_mag']:.1f}*) {x['text']}" for x in subspaces]
-
 def filter_concepts(search_text: str):
     if not search_text.strip():
         return concept_list[:500]
@@ -144,15 +197,21 @@ def filter_concepts(search_text: str):

 def add_concept_to_list(selected_concept, user_slider_val, current_list):
     if not selected_concept:
-        return current_list, gr.update(choices=_build_remove_choices(current_list))
+        return current_list

-    idx = concept_id_map[selected_concept]
+    selected_concept_text = None
+    if selected_concept.startswith("[New] "):
+        selected_concept_text = selected_concept[6:]
+        idx = 0
+    else:
+        idx = concept_id_map[selected_concept]
     internal_mag = user_slider_val * 50
     new_entry = {
         "text": selected_concept,
         "idx": idx,
         "display_mag": user_slider_val,
         "internal_mag": internal_mag,
+        "subspace_gen_text": selected_concept_text
     }
     # Add to the beginning of the list
     current_list = [new_entry]
@@ -160,16 +219,23 @@ def add_concept_to_list(selected_concept, user_slider_val, current_list):

 def update_dropdown_choices(search_text):
     filtered = filter_concepts(search_text)
-    if not filtered:
-        return gr.update(choices=[], value=None, interactive=True)
+    if not filtered or len(filtered) == 0:
+        return gr.update(choices=[f"[New] {search_text}"], value=f"[New] {search_text}", interactive=True), gr.Textbox(
+            label="No matching existing concepts were found!",
+            value="Good news! Based on the concept you provided, we will automatically generate a steering vector. Try it out by starting a chat!",
+            lines=3,
+            interactive=False,
+            visible=True,
+            elem_id="alert-message"
+        )
     # Automatically select the first matching concept
     return gr.update(
         choices=filtered,
         value=filtered[0], # Select the first match
-        interactive=True
-    )
+        interactive=True, visible=True
+    ), gr.Textbox(visible=False)

-with gr.Blocks(fill_height=True) as demo:
+with gr.Blocks(css=css, fill_height=True) as demo:
     # Remove default subspaces
     selected_subspaces = gr.State([])

@@ -179,7 +245,7 @@ with gr.Blocks(fill_height=True) as demo:
             chat_interface = gr.ChatInterface(
                 fn=generate,
                 title="Chat with a Concept Steering Model",
-                description="Steer responses by selecting concepts on the right ",
+                description="""Steer responses by selecting concepts on the right →\n\nWe are using Gemma-2-2B-it with steering vectors added to the residual stream at layer 20. Our auto-steer steering vector generator is a finetuned Gemma-2-2B model.""",
                 type="messages",
                 additional_inputs=[selected_subspaces],
                 fill_height=True
@@ -188,7 +254,7 @@ with gr.Blocks(fill_height=True) as demo:
         # Right side: concept management
         with gr.Column(scale=4):
             gr.Markdown("## Steer Model Responses")
-            gr.Markdown("Search and then select a concept to steer. The closest match will be automatically selected.")
+            gr.Markdown("Search and then select a concept to steer. The closest match will be automatically selected. If there is no match, we will use our steering vector generator to auto-steer for you!")
             # Concept Search and Selection
             with gr.Group():
                 search_box = gr.Textbox(
@@ -196,6 +262,7 @@ with gr.Blocks(fill_height=True) as demo:
                     placeholder="Find concepts to steer the model (e.g. 'time travel')",
                     lines=2,
                 )
+                msg = gr.TextArea(visible=False)
                 concept_dropdown = gr.Dropdown(
                     label="Select a concept to steer the model (Click to see more!)",
                     interactive=True,
@@ -211,10 +278,10 @@

             # Wire up events
             # When search box changes, update dropdown AND trigger concept selection
-            search_box.change(
+            search_box.input(
                 update_dropdown_choices,
                 [search_box],
-                [concept_dropdown]
+                [concept_dropdown, msg]
             ).then( # Chain the events to automatically add the concept
                 add_concept_to_list,
                 [concept_dropdown, concept_magnitude, selected_subspaces],
@@ -227,6 +294,12 @@ with gr.Blocks(fill_height=True) as demo:
                 [selected_subspaces]
             )

+            concept_dropdown.change(
+                add_concept_to_list,
+                [concept_dropdown, concept_magnitude, selected_subspaces],
+                [selected_subspaces]
+            )
+
             concept_magnitude.input(
                 add_concept_to_list,
                 [concept_dropdown, concept_magnitude, selected_subspaces],
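Reviewer's note: the core behavioral change is the new dual path in Steer.forward. The payload built in generate() now carries a single idx/mag pair plus optional subspace_gen_inputs; the intervention either looks up a learned dictionary row or asks the generator for a vector on the fly. Below is a minimal, self-contained sketch of that logic, not the commit's code: it uses toy dimensions and a stand-in generator, and only the keys idx, mag, and subspace_gen_inputs are taken from the diff.

import torch

# Toy dimensions; the real app uses Gemma-2-2B's hidden size.
embed_dim, latent_dim = 8, 5
proj = torch.nn.Linear(embed_dim, latent_dim, bias=False)

def stand_in_generator(input_ids, attention_mask):
    # Stands in for RegressionWrapper: one unit-norm vector per batch item.
    vec = torch.randn(input_ids.shape[0], embed_dim)
    return torch.nn.functional.normalize(vec, p=2, dim=-1)

def apply_steering(base, subspaces):
    if subspaces is None:
        return base
    if subspaces["subspace_gen_inputs"] is not None:
        # Auto-steer branch: generate the vector from the tokenized concept text.
        raw = stand_in_generator(**subspaces["subspace_gen_inputs"])[0]
    else:
        # Dictionary branch: look up a learned row by concept index.
        raw = proj.weight[subspaces["idx"]]
    return base + torch.tensor(subspaces["mag"]) * raw.unsqueeze(dim=0)

base = torch.zeros(1, embed_dim)
# Existing concept: dictionary index, no generator inputs.
out1 = apply_steering(base, {"idx": 2, "mag": 100, "subspace_gen_inputs": None})
# "[New] ..." concept: inputs as produced by base_tokenizer(text, return_tensors="pt").
toy_inputs = {"input_ids": torch.tensor([[1, 2, 3]]),
              "attention_mask": torch.ones(1, 3, dtype=torch.long)}
out2 = apply_steering(base, {"idx": 0, "mag": 100, "subspace_gen_inputs": toy_inputs})
print(out1.shape, out2.shape)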
style.css DELETED
@@ -1,19 +0,0 @@
-#alert-message label {
-    font-weight: 700;
-    background-color: #fff3cd;
-    padding: 8px;
-    border-radius: 4px;
-    color: #664d03;
-    display: inline-block;
-    margin-bottom: 8px;
-}
-
-#alert-message textarea {
-    background-color: #e8f4ff;
-    border: 1px solid #cce5ff;
-    color: #084298;
-    font-size: 1.1em;
-    padding: 12px;
-    border-radius: 4px;
-    font-weight: 500;
-}
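Note that only the #alert-message textarea rules survive the move: they reappear as the inline css string in app.py above, while the #alert-message label block is dropped. A minimal sketch of the inline-CSS pattern this commit switches to, assuming only stock Gradio behavior (gr.Blocks(css=...) plus elem_id scoping); the values here are illustrative:

import gradio as gr

# CSS inlined into app.py and scoped to one component via elem_id,
# replacing a separate style.css file.
css = """
#alert-message textarea { background-color: #e8f4ff; color: #084298; }
"""

with gr.Blocks(css=css) as demo:
    gr.TextArea(value="Styled alert text", elem_id="alert-message", interactive=False)

if __name__ == "__main__":
    demo.launch()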